@@ -612,8 +612,10 @@ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
 
     def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
         r"""
-        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
-        the number of heads initialized while constructing the `Attention` class.
+        Reshape the tensor from `[batch_size, seq_len, dim]` to
+        `[batch_size, seq_len, heads, dim // heads]` for `out_dim == 4`,
+        or `[batch_size * heads, seq_len, dim // heads]` for `out_dim == 3`,
+        where `heads` is the number of heads initialized while constructing the `Attention` class.
 
         Args:
             tensor (`torch.Tensor`): The tensor to reshape.
@@ -630,9 +632,10 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
         else:
             batch_size, extra_dim, seq_len, dim = tensor.shape
             tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3)
-
+
+        assert out_dim in [3, 4]
         if out_dim == 3:
+            tensor = tensor.permute(0, 2, 1, 3)
             tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
 
         return tensor
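
With this change the `permute` only runs on the `out_dim == 3` path, so `out_dim == 4` now returns `[batch_size, seq_len, heads, dim // heads]` rather than a head-first layout. Below is a minimal standalone sketch of the resulting reshape logic, re-implemented as a free function purely for illustration (the real method lives on the `Attention` class and reads `heads` from `self.heads`):

```python
import torch

def head_to_batch_dim(tensor: torch.Tensor, heads: int, out_dim: int = 3) -> torch.Tensor:
    # Mirror of the post-change logic: split the channel dim into heads first.
    if tensor.ndim == 3:
        batch_size, seq_len, dim = tensor.shape
        extra_dim = 1
    else:
        batch_size, extra_dim, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len * extra_dim, heads, dim // heads)

    assert out_dim in [3, 4]
    if out_dim == 3:
        # Only the 3-D output moves heads next to batch and flattens them.
        tensor = tensor.permute(0, 2, 1, 3)  # [batch, heads, seq, dim // heads]
        tensor = tensor.reshape(batch_size * heads, seq_len * extra_dim, dim // heads)
    return tensor

x = torch.randn(2, 16, 64)  # [batch_size, seq_len, dim]
print(head_to_batch_dim(x, heads=8).shape)              # torch.Size([16, 16, 8]) -> [batch * heads, seq, dim // heads]
print(head_to_batch_dim(x, heads=8, out_dim=4).shape)   # torch.Size([2, 16, 8, 8]) -> [batch, seq, heads, dim // heads]
```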