
Commit bc71e63

Update attention_processor.py
1 parent 9b24fb5 commit bc71e63

1 file changed (+4, -3 lines)

src/diffusers/models/attention_processor.py

Lines changed: 4 additions & 3 deletions
@@ -612,7 +612,7 @@ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
 
     def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
         r"""
-        Reshape the tensor from `[batch_size, seq_len, dim]` to
+        Reshape the tensor from `[batch_size, seq_len, dim]` to
         `[batch_size, seq_len, heads, dim // heads]` for out_dim==4
         or `[batch_size * heads, seq_len, dim // heads]` for out_dim==3
         where `heads` is the number of heads initialized while constructing the `Attention` class.
@@ -625,15 +625,16 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
         Returns:
             `torch.Tensor`: The reshaped tensor.
         """
+        if out_dim not in {3, 4}:
+            raise ValueError(f"Expected `out_dim` to be 3 or 4, got {out_dim}.")
         head_size = self.heads
         if tensor.ndim == 3:
             batch_size, seq_len, dim = tensor.shape
             extra_dim = 1
         else:
             batch_size, extra_dim, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
-
-        assert out_dim in [3,4]
+
         if out_dim == 3:
             tensor = tensor.permute(0, 2, 1, 3)
             tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
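
For reference, here is a minimal, self-contained sketch of what `head_to_batch_dim` does after this change, with a plain `heads` argument standing in for `self.heads`. It is an illustration under those assumptions, not the diffusers implementation itself:

```python
import torch


def head_to_batch_dim(tensor: torch.Tensor, heads: int, out_dim: int = 3) -> torch.Tensor:
    # Mirrors the new explicit check; unlike the removed `assert`, a ValueError
    # is still raised when Python runs with `-O`, which strips assert statements.
    if out_dim not in {3, 4}:
        raise ValueError(f"Expected `out_dim` to be 3 or 4, got {out_dim}.")
    if tensor.ndim == 3:
        batch_size, seq_len, dim = tensor.shape
        extra_dim = 1
    else:
        batch_size, extra_dim, seq_len, dim = tensor.shape
    # Split the channel dimension into (heads, dim // heads).
    tensor = tensor.reshape(batch_size, seq_len * extra_dim, heads, dim // heads)
    if out_dim == 3:
        # Fold the head dimension into the batch dimension.
        tensor = tensor.permute(0, 2, 1, 3)
        tensor = tensor.reshape(batch_size * heads, seq_len * extra_dim, dim // heads)
    return tensor


x = torch.randn(2, 77, 64)                              # [batch_size, seq_len, dim]
print(head_to_batch_dim(x, heads=8).shape)              # torch.Size([16, 77, 8])
print(head_to_batch_dim(x, heads=8, out_dim=4).shape)   # torch.Size([2, 77, 8, 8])
```

Moving the validation to the top of the method also means an invalid `out_dim` fails before any reshaping work is done, and the error message now reports the offending value.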
