From e713b4374fd907277550366a0399d9472b98727c Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 17 Dec 2024 21:03:59 +0100 Subject: [PATCH] fix joint pos embedding device --- src/diffusers/models/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b423c17c1246..a3431a5902de 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -692,7 +692,7 @@ def _get_positional_embeddings( output_type="pt", ) pos_embedding = pos_embedding.flatten(0, 1) - joint_pos_embedding = torch.zeros( + joint_pos_embedding = pos_embedding.new_zeros( 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False ) joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)