Skip to content

Commit 2ec9ffa

Browse files
Andreas JörgAndreas Jörg
authored andcommitted
Fix: dtype mismatch of prompt embeddings in sd3 controlnet training
1 parent 20e4b6a commit 2ec9ffa

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/controlnet/train_controlnet_sd3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,8 +1283,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
12831283
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
12841284

12851285
# Get the text embedding for conditioning
1286-
prompt_embeds = batch["prompt_embeds"]
1287-
pooled_prompt_embeds = batch["pooled_prompt_embeds"]
1286+
prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype)
1287+
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype)
12881288

12891289
# controlnet(s) inference
12901290
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)

0 commit comments

Comments
 (0)