Skip to content

Commit 2b5b2e3

Browse files
committed
fix dtype casting in timestep guidance module.
1 parent 4082c43 commit 2b5b2e3

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/diffusers/models/transformers/transformer_flux2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,10 +523,10 @@ def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool
523523

524524
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
525525
timesteps_proj = self.time_proj(timestep)
526-
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
526+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
527527

528528
guidance_proj = self.time_proj(guidance)
529-
guidance_emb = self.guidance_embedder(guidance_proj) # (N, D)
529+
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
530530

531531
time_guidance_emb = timesteps_emb + guidance_emb
532532

0 commit comments

Comments
 (0)