diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 4558d48edad9..1768c81ce039 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -748,10 +748,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): pos_embedding = self._get_positional_embeddings( height, width, pre_time_compression_frames, device=embeds.device ) - pos_embedding = pos_embedding.to(dtype=embeds.dtype) else: pos_embedding = self.pos_embedding + pos_embedding = pos_embedding.to(dtype=embeds.dtype) embeds = embeds + pos_embedding return embeds