diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index cdebad43..2872d4b4 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -362,7 +362,7 @@ def forward(self, **kwargs, ): t = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, timestep)) + sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype)) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) context = self.text_embedding(context)