From 9b24fb5d9ac5456b513278a0d3f5e3eba8ce7c98 Mon Sep 17 00:00:00 2001 From: Kaifeng Gao Date: Fri, 20 Dec 2024 05:40:09 +0000 Subject: [PATCH 1/3] fix bug of Attention.head_to_batch_dim issue #10303 --- src/diffusers/models/attention_processor.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ed0dd4f71d27..fb21f69c5a6f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -612,8 +612,10 @@ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: r""" - Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is - the number of heads initialized while constructing the `Attention` class. + Reshape the tensor from `[batch_size, seq_len, dim]` to + `[batch_size, seq_len, heads, dim // heads]` for out_dim==4 + or `[batch_size * heads, seq_len, dim // heads]` for out_dim==3 + where `heads` is the number of heads initialized while constructing the `Attention` class. Args: tensor (`torch.Tensor`): The tensor to reshape. @@ -630,9 +632,10 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Ten else: batch_size, extra_dim, seq_len, dim = tensor.shape tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3) - + + assert out_dim in [3,4] if out_dim == 3: + tensor = tensor.permute(0, 2, 1, 3) tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) return tensor From bc71e638e605d1e5f900cf5c1e24742439d3d18f Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 20 Dec 2024 11:59:49 +0000 Subject: [PATCH 2/3] Update attention_processor.py --- src/diffusers/models/attention_processor.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index fb21f69c5a6f..7d0f37a2b9b4 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -612,7 +612,7 @@ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: r""" - Reshape the tensor from `[batch_size, seq_len, dim]` to + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` for out_dim==4 or `[batch_size * heads, seq_len, dim // heads]` for out_dim==3 where `heads` is the number of heads initialized while constructing the `Attention` class. @@ -625,6 +625,8 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Ten Returns: `torch.Tensor`: The reshaped tensor. """ + if out_dim not in {3, 4}: + raise ValueError(f"Expected `out_dim` to be 3 or 4, got {out_dim}.") head_size = self.heads if tensor.ndim == 3: batch_size, seq_len, dim = tensor.shape @@ -632,8 +634,7 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Ten else: batch_size, extra_dim, seq_len, dim = tensor.shape tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) - - assert out_dim in [3,4] + if out_dim == 3: tensor = tensor.permute(0, 2, 1, 3) tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) From 7f84b000ba28484feaaf4a7edb6c495ae985506c Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 20 Dec 2024 12:04:37 +0000 Subject: [PATCH 3/3] make style --- src/diffusers/models/attention_processor.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7d0f37a2b9b4..e6b8ed4e6687 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -612,10 +612,9 @@ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: r""" - Reshape the tensor from `[batch_size, seq_len, dim]` to - `[batch_size, seq_len, heads, dim // heads]` for out_dim==4 - or `[batch_size * heads, seq_len, dim // heads]` for out_dim==3 - where `heads` is the number of heads initialized while constructing the `Attention` class. + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` for + out_dim==4 or `[batch_size * heads, seq_len, dim // heads]` for out_dim==3 where `heads` is the number of heads + initialized while constructing the `Attention` class. Args: tensor (`torch.Tensor`): The tensor to reshape.