diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index ed0dd4f71d27..e6b8ed4e6687 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -612,8 +612,9 @@ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
 
     def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
         r"""
-        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
-        the number of heads initialized while constructing the `Attention` class.
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` for
+        out_dim==4 or `[batch_size * heads, seq_len, dim // heads]` for out_dim==3 where `heads` is the number of heads
+        initialized while constructing the `Attention` class.
 
         Args:
             tensor (`torch.Tensor`): The tensor to reshape.
@@ -623,6 +624,8 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Ten
         Returns:
             `torch.Tensor`: The reshaped tensor.
         """
+        if out_dim not in {3, 4}:
+            raise ValueError(f"Expected `out_dim` to be 3 or 4, got {out_dim}.")
         head_size = self.heads
         if tensor.ndim == 3:
             batch_size, seq_len, dim = tensor.shape
@@ -630,9 +633,9 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Ten
         else:
             batch_size, extra_dim, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3)
 
         if out_dim == 3:
+            tensor = tensor.permute(0, 2, 1, 3)
             tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
 
         return tensor
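
For reference, below is a minimal standalone sketch of the output shapes the updated docstring describes, written with plain torch ops rather than the `Attention` class itself. The tensor sizes are arbitrary illustration values (not taken from the library), and the variable names are only for this example:

    import torch

    # Arbitrary example sizes (hypothetical, chosen only to illustrate the shapes).
    batch_size, seq_len, heads, dim = 2, 77, 8, 640

    tensor = torch.randn(batch_size, seq_len, dim)

    # Split the channel dimension into heads, as the patched code does before any permute:
    # [batch_size, seq_len, dim] -> [batch_size, seq_len, heads, dim // heads]
    out4 = tensor.reshape(batch_size, seq_len, heads, dim // heads)
    assert out4.shape == (batch_size, seq_len, heads, dim // heads)  # shape returned for out_dim == 4

    # For out_dim == 3, the heads are moved next to the batch axis and folded into it:
    # [batch_size, seq_len, heads, dim // heads] -> [batch_size * heads, seq_len, dim // heads]
    out3 = out4.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim // heads)
    assert out3.shape == (batch_size * heads, seq_len, dim // heads)

The sketch mirrors why the patch moves the permute under `if out_dim == 3:` -- only the 3-d output needs heads adjacent to the batch dimension before merging, while the 4-d output keeps `seq_len` in the second position as documented.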