@@ -913,38 +913,21 @@ def patchify(x, patch_size):
913913    if  patch_size  ==  1 :
914914        return  x 
915915
916-     if  x .dim () ==  4 :
917-         # x shape: [batch_size, channels, height, width] 
918-         batch_size , channels , height , width  =  x .shape 
919- 
920-         # Ensure height and width are divisible by patch_size 
921-         if  height  %  patch_size  !=  0  or  width  %  patch_size  !=  0 :
922-             raise  ValueError (f"Height ({ height }  ) and width ({ width }  ) must be divisible by patch_size ({ patch_size }  )" )
923- 
924-         # Reshape to [batch_size, channels, height//patch_size, patch_size, width//patch_size, patch_size] 
925-         x  =  x .view (batch_size , channels , height  //  patch_size , patch_size , width  //  patch_size , patch_size )
926- 
927-         # Rearrange to [batch_size, channels * patch_size * patch_size, height//patch_size, width//patch_size] 
928-         x  =  x .permute (0 , 1 , 3 , 5 , 2 , 4 ).contiguous ()
929-         x  =  x .view (batch_size , channels  *  patch_size  *  patch_size , height  //  patch_size , width  //  patch_size )
930- 
931-     elif  x .dim () ==  5 :
932-         # x shape: [batch_size, channels, frames, height, width] 
933-         batch_size , channels , frames , height , width  =  x .shape 
934- 
935-         # Ensure height and width are divisible by patch_size 
936-         if  height  %  patch_size  !=  0  or  width  %  patch_size  !=  0 :
937-             raise  ValueError (f"Height ({ height }  ) and width ({ width }  ) must be divisible by patch_size ({ patch_size }  )" )
916+     if  x .dim () !=  5 :
917+         raise  ValueError (f"Invalid input shape: { x .shape }  " )
918+     # x shape: [batch_size, channels, frames, height, width] 
919+     batch_size , channels , frames , height , width  =  x .shape 
938920
939-         # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size] 
940-         x  =  x .view (batch_size , channels , frames , height  //  patch_size , patch_size , width  //  patch_size , patch_size )
921+     # Ensure height and width are divisible by patch_size 
922+     if  height  %  patch_size  !=  0  or  width  %  patch_size  !=  0 :
923+         raise  ValueError (f"Height ({ height }  ) and width ({ width }  ) must be divisible by patch_size ({ patch_size }  )" )
941924
942-         # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size] 
943-         x  =  x .permute (0 , 1 , 4 , 6 , 2 , 3 , 5 ).contiguous ()
944-         x  =  x .view (batch_size , channels  *  patch_size  *  patch_size , frames , height  //  patch_size , width  //  patch_size )
925+     # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size] 
926+     x  =  x .view (batch_size , channels , frames , height  //  patch_size , patch_size , width  //  patch_size , patch_size )
945927
946-     else :
947-         raise  ValueError (f"Invalid input shape: { x .shape }  " )
928+     # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size] 
929+     x  =  x .permute (0 , 1 , 6 , 4 , 2 , 3 , 5 ).contiguous ()
930+     x  =  x .view (batch_size , channels  *  patch_size  *  patch_size , frames , height  //  patch_size , width  //  patch_size )
948931
949932    return  x 
950933
@@ -953,29 +936,18 @@ def unpatchify(x, patch_size):
953936    if  patch_size  ==  1 :
954937        return  x 
955938
956-     if  x .dim () ==  4 :
957-         # x shape: [b, (c * patch_size * patch_size), h, w] 
958-         batch_size , c_patches , height , width  =  x .shape 
959-         channels  =  c_patches  //  (patch_size  *  patch_size )
960- 
961-         # Reshape to [b, c, patch_size, patch_size, h, w] 
962-         x  =  x .view (batch_size , channels , patch_size , patch_size , height , width )
963- 
964-         # Rearrange to [b, c, h * patch_size, w * patch_size] 
965-         x  =  x .permute (0 , 1 , 4 , 2 , 5 , 3 ).contiguous ()
966-         x  =  x .view (batch_size , channels , height  *  patch_size , width  *  patch_size )
967- 
968-     elif  x .dim () ==  5 :
969-         # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width] 
970-         batch_size , c_patches , frames , height , width  =  x .shape 
971-         channels  =  c_patches  //  (patch_size  *  patch_size )
939+     if  x .dim () !=  5 :
940+         raise  ValueError (f"Invalid input shape: { x .shape }  " )
941+     # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width] 
942+     batch_size , c_patches , frames , height , width  =  x .shape 
943+     channels  =  c_patches  //  (patch_size  *  patch_size )
972944
973-          # Reshape to [b, c, patch_size, patch_size, f, h, w] 
974-          x  =  x .view (batch_size , channels , patch_size , patch_size , frames , height , width )
945+     # Reshape to [b, c, patch_size, patch_size, f, h, w] 
946+     x  =  x .view (batch_size , channels , patch_size , patch_size , frames , height , width )
975947
976-          # Rearrange to [b, c, f, h * patch_size, w * patch_size] 
977-          x  =  x .permute (0 , 1 , 4 , 5 , 2 , 6 , 3 ).contiguous ()
978-          x  =  x .view (batch_size , channels , frames , height  *  patch_size , width  *  patch_size )
948+     # Rearrange to [b, c, f, h * patch_size, w * patch_size] 
949+     x  =  x .permute (0 , 1 , 4 , 5 , 3 , 6 , 2 ).contiguous ()
950+     x  =  x .view (batch_size , channels , frames , height  *  patch_size , width  *  patch_size )
979951
980952    return  x 
981953
@@ -1044,7 +1016,6 @@ def __init__(
10441016        patch_size : Optional [int ] =  None ,
10451017        scale_factor_temporal : Optional [int ] =  4 ,
10461018        scale_factor_spatial : Optional [int ] =  8 ,
1047-         clip_output : bool  =  True ,
10481019    ) ->  None :
10491020        super ().__init__ ()
10501021
@@ -1244,10 +1215,11 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True):
12441215                out_  =  self .decoder (x [:, :, i  : i  +  1 , :, :], feat_cache = self ._feat_map , feat_idx = self ._conv_idx )
12451216                out  =  torch .cat ([out , out_ ], 2 )
12461217
1247-         if  self .config .clip_output :
1248-             out  =  torch .clamp (out , min = - 1.0 , max = 1.0 )
12491218        if  self .config .patch_size  is  not   None :
12501219            out  =  unpatchify (out , patch_size = self .config .patch_size )
1220+ 
1221+         out  =  torch .clamp (out , min = - 1.0 , max = 1.0 )
1222+ 
12511223        self .clear_cache ()
12521224        if  not  return_dict :
12531225            return  (out ,)
0 commit comments