
Commit 9b24fb5

fix bug in Attention.head_to_batch_dim (issue #10303)
1 parent 41ba8c0 commit 9b24fb5

File tree

1 file changed (+7, -4 lines)

src/diffusers/models/attention_processor.py

Lines changed: 7 additions & 4 deletions
@@ -612,8 +612,10 @@ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
 
     def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
         r"""
-        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
-        the number of heads initialized while constructing the `Attention` class.
+        Reshape the tensor from `[batch_size, seq_len, dim]` to
+        `[batch_size, seq_len, heads, dim // heads]` for out_dim==4
+        or `[batch_size * heads, seq_len, dim // heads]` for out_dim==3
+        where `heads` is the number of heads initialized while constructing the `Attention` class.
 
         Args:
             tensor (`torch.Tensor`): The tensor to reshape.
@@ -630,9 +632,10 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
         else:
             batch_size, extra_dim, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3)
-
+
+        assert out_dim in [3,4]
         if out_dim == 3:
+            tensor = tensor.permute(0, 2, 1, 3)
             tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
 
         return tensor
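
For reference, below is a minimal standalone sketch of the reshaping logic after this patch, written in plain PyTorch rather than against the `Attention` class itself. The branch for 3-D inputs (setting `extra_dim = 1`) lies outside the diff context above and is reproduced here as an assumption; `head_size` stands in for `self.heads`.

```python
import torch


def head_to_batch_dim(tensor: torch.Tensor, head_size: int, out_dim: int = 3) -> torch.Tensor:
    """Standalone sketch of the patched reshaping logic (not the diffusers class itself)."""
    if tensor.ndim == 3:
        # 3-D input branch; `extra_dim = 1` is assumed, as this branch is
        # outside the diff context shown above.
        batch_size, seq_len, dim = tensor.shape
        extra_dim = 1
    else:
        batch_size, extra_dim, seq_len, dim = tensor.shape

    # Split the channel dimension into (heads, dim // heads).
    tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)

    assert out_dim in [3, 4]
    if out_dim == 3:
        # Only the 3-D output moves heads next to the batch dimension and folds
        # them together; for out_dim == 4 the tensor is returned as
        # [batch_size, seq_len, heads, dim // heads], matching the new docstring.
        tensor = tensor.permute(0, 2, 1, 3)
        tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)

    return tensor


x = torch.randn(2, 16, 64)  # [batch_size=2, seq_len=16, dim=64], 8 heads of size 8
print(head_to_batch_dim(x, head_size=8, out_dim=3).shape)  # torch.Size([16, 16, 8])
print(head_to_batch_dim(x, head_size=8, out_dim=4).shape)  # torch.Size([2, 16, 8, 8])
```

With the `permute` moved inside the `out_dim == 3` branch, `out_dim=4` now yields `[batch_size, seq_len, heads, dim // heads]` as the updated docstring describes, while `out_dim=3` still yields `[batch_size * heads, seq_len, dim // heads]`.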
