Skip to content

Commit ed14d26

Browse files
committed
fix UNet1DModelTests::test_layerwise_upcasting_inference
1 parent 4450b1c commit ed14d26

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/diffusers/models/unets/unet_1d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def forward(
225225

226226
timestep_embed = self.time_proj(timesteps)
227227
if self.config.use_timestep_embedding:
228-
timestep_embed = self.time_mlp(timestep_embed)
228+
timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype))
229229
else:
230230
timestep_embed = timestep_embed[..., None]
231231
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)

0 commit comments

Comments
 (0)