Skip to content

Commit e713b43

Browse files
committed
fix joint pos embedding device
1 parent f9d5a93 commit e713b43

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/diffusers/models/embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ def _get_positional_embeddings(
692692
output_type="pt",
693693
)
694694
pos_embedding = pos_embedding.flatten(0, 1)
695-
joint_pos_embedding = torch.zeros(
695+
joint_pos_embedding = pos_embedding.new_zeros(
696696
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
697697
)
698698
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)

0 commit comments

Comments
 (0)