Commit 1c47d1f

Fix head_to_batch_dim for IPAdapterAttnProcessor (#7077)

* Fix IPAdapterAttnProcessor
* Fix batch_to_head_dim and revert reshape

Parent commit: bbf70c8

File tree: 1 file changed (+7, -3 lines)

src/diffusers/models/attention_processor.py

Lines changed: 7 additions & 3 deletions
@@ -559,12 +559,16 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
             `torch.Tensor`: The reshaped tensor.
         """
         head_size = self.heads
-        batch_size, seq_len, dim = tensor.shape
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        if tensor.ndim == 3:
+            batch_size, seq_len, dim = tensor.shape
+            extra_dim = 1
+        else:
+            batch_size, extra_dim, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
         tensor = tensor.permute(0, 2, 1, 3)
 
         if out_dim == 3:
-            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+            tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
 
         return tensor
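For context: the old code unpacked tensor.shape into exactly three values, so the 4-D tensors that IPAdapterAttnProcessor can pass (an extra leading dimension ahead of the sequence axis) would fail to unpack; the patch detects the rank and folds any extra dimension into the sequence axis. Below is a minimal, self-contained sketch of the patched logic, with the function lifted out of the Attention class; the head count and example shapes are illustrative assumptions, not values taken from the commit.

import torch

# Sketch of the patched head_to_batch_dim, pulled out of the Attention
# class for illustration. `heads` and the shapes below are assumptions.
def head_to_batch_dim(tensor: torch.Tensor, heads: int, out_dim: int = 3) -> torch.Tensor:
    head_size = heads
    if tensor.ndim == 3:
        batch_size, seq_len, dim = tensor.shape
        extra_dim = 1
    else:
        # 4-D input, e.g. stacked per-image IP-Adapter embeds:
        # [batch_size, extra_dim, seq_len, dim]
        batch_size, extra_dim, seq_len, dim = tensor.shape
    # Fold any extra leading dimension into the sequence axis, then split
    # the feature dimension across attention heads.
    tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
    tensor = tensor.permute(0, 2, 1, 3)
    if out_dim == 3:
        tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
    return tensor

# Before the fix, the 4-D case raised "ValueError: too many values to
# unpack"; with the patch both ranks go through:
x3 = torch.randn(2, 77, 64)      # [batch_size, seq_len, dim]
x4 = torch.randn(2, 4, 77, 64)   # [batch_size, extra_dim, seq_len, dim]
print(head_to_batch_dim(x3, heads=8).shape)  # torch.Size([16, 77, 8])
print(head_to_batch_dim(x4, heads=8).shape)  # torch.Size([16, 308, 8])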
