@@ -612,7 +612,7 @@ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
 
     def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
         r"""
-        Reshape the tensor from `[batch_size, seq_len, dim]` to
+        Reshape the tensor from `[batch_size, seq_len, dim]` to
         `[batch_size, seq_len, heads, dim // heads]` for out_dim==4
         or `[batch_size * heads, seq_len, dim // heads]` for out_dim==3
         where `heads` is the number of heads initialized while constructing the `Attention` class.
@@ -625,15 +625,16 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
         Returns:
             `torch.Tensor`: The reshaped tensor.
         """
+        if out_dim not in {3, 4}:
+            raise ValueError(f"Expected `out_dim` to be 3 or 4, got {out_dim}.")
         head_size = self.heads
         if tensor.ndim == 3:
             batch_size, seq_len, dim = tensor.shape
             extra_dim = 1
         else:
             batch_size, extra_dim, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
-
-        assert out_dim in [3, 4]
+
         if out_dim == 3:
             tensor = tensor.permute(0, 2, 1, 3)
             tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
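
For context, the reshape these lines perform can be checked in isolation. The snippet below is a minimal sketch (not part of the commit) that reproduces the `out_dim == 3` path with hypothetical sizes (`heads = 8`, `dim = 320`), mirroring what `head_to_batch_dim` does to a 3-D input.

import torch

# Hypothetical sizes; `heads` stands in for Attention.heads in the diff above.
batch_size, seq_len, dim, heads = 2, 77, 320, 8

tensor = torch.randn(batch_size, seq_len, dim)  # [batch_size, seq_len, dim]

# Split the feature dimension into (heads, dim // heads).
tensor = tensor.reshape(batch_size, seq_len, heads, dim // heads)

# out_dim == 3: fold the heads into the batch dimension.
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim // heads)

print(tensor.shape)  # torch.Size([16, 77, 40])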