diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py
index 2e902db7ffc7..eaeb697c64c0 100644
--- a/examples/controlnet/train_controlnet.py
+++ b/examples/controlnet/train_controlnet.py
@@ -1048,7 +1048,9 @@ def load_model_hook(models, input_dir):
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to(
+                    dtype=weight_dtype
+                )
 
                 # Get the text embedding for conditioning
                 encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index 877ca6135849..ae627bb3a04c 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -1210,7 +1210,9 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to(
+                    dtype=weight_dtype
+                )
 
                 # ControlNet conditioning.
                 controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
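Context for the change: under mixed-precision training the latents and noise arrive in half precision, and `DDPMScheduler.add_noise` casts its internal `alphas_cumprod` coefficients to the dtype of `original_samples`, so the forward-diffusion arithmetic would otherwise run entirely in fp16/bf16. Upcasting the inputs to float32 keeps that step in full precision, and the trailing `.to(dtype=weight_dtype)` casts the result back so the UNet/ControlNet forward pass stays in the training dtype. Below is a minimal standalone sketch of the same pattern; the scheduler config, tensor shapes, and the `weight_dtype = torch.float16` value are illustrative assumptions, not values taken from the training scripts.

```python
import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
weight_dtype = torch.float16  # assumed mixed-precision dtype for illustration

# Illustrative latent batch: shapes are placeholders, not the scripts' real sizes.
latents = torch.randn(4, 4, 64, 64, dtype=weight_dtype)
noise = torch.randn_like(latents)
timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],)
).long()

# add_noise computes sqrt(alpha_prod_t) * latents + sqrt(1 - alpha_prod_t) * noise,
# with coefficients cast to the input dtype. Passing float32 inputs keeps this in
# full precision; the result is then cast back to weight_dtype for the model.
noisy_latents = noise_scheduler.add_noise(
    latents.float(), noise.float(), timesteps
).to(dtype=weight_dtype)
```

The design choice is that only the `add_noise` arithmetic pays the float32 cost: everything downstream still sees `weight_dtype` tensors, so the memory and speed benefits of mixed precision are preserved.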