@@ -913,38 +913,21 @@ def patchify(x, patch_size):
913913 if patch_size == 1 :
914914 return x
915915
916- if x .dim () == 4 :
917- # x shape: [batch_size, channels, height, width]
918- batch_size , channels , height , width = x .shape
919-
920- # Ensure height and width are divisible by patch_size
921- if height % patch_size != 0 or width % patch_size != 0 :
922- raise ValueError (f"Height ({ height } ) and width ({ width } ) must be divisible by patch_size ({ patch_size } )" )
923-
924- # Reshape to [batch_size, channels, height//patch_size, patch_size, width//patch_size, patch_size]
925- x = x .view (batch_size , channels , height // patch_size , patch_size , width // patch_size , patch_size )
926-
927- # Rearrange to [batch_size, channels * patch_size * patch_size, height//patch_size, width//patch_size]
928- x = x .permute (0 , 1 , 3 , 5 , 2 , 4 ).contiguous ()
929- x = x .view (batch_size , channels * patch_size * patch_size , height // patch_size , width // patch_size )
930-
931- elif x .dim () == 5 :
932- # x shape: [batch_size, channels, frames, height, width]
933- batch_size , channels , frames , height , width = x .shape
934-
935- # Ensure height and width are divisible by patch_size
936- if height % patch_size != 0 or width % patch_size != 0 :
937- raise ValueError (f"Height ({ height } ) and width ({ width } ) must be divisible by patch_size ({ patch_size } )" )
916+ if x .dim () != 5 :
917+ raise ValueError (f"Invalid input shape: { x .shape } " )
918+ # x shape: [batch_size, channels, frames, height, width]
919+ batch_size , channels , frames , height , width = x .shape
938920
939- # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
940- x = x .view (batch_size , channels , frames , height // patch_size , patch_size , width // patch_size , patch_size )
921+ # Ensure height and width are divisible by patch_size
922+ if height % patch_size != 0 or width % patch_size != 0 :
923+ raise ValueError (f"Height ({ height } ) and width ({ width } ) must be divisible by patch_size ({ patch_size } )" )
941924
942- # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
943- x = x .permute (0 , 1 , 4 , 6 , 2 , 3 , 5 ).contiguous ()
944- x = x .view (batch_size , channels * patch_size * patch_size , frames , height // patch_size , width // patch_size )
925+ # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
926+ x = x .view (batch_size , channels , frames , height // patch_size , patch_size , width // patch_size , patch_size )
945927
946- else :
947- raise ValueError (f"Invalid input shape: { x .shape } " )
928+ # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
929+ x = x .permute (0 , 1 , 6 , 4 , 2 , 3 , 5 ).contiguous ()
930+ x = x .view (batch_size , channels * patch_size * patch_size , frames , height // patch_size , width // patch_size )
948931
949932 return x
950933
@@ -953,29 +936,18 @@ def unpatchify(x, patch_size):
953936 if patch_size == 1 :
954937 return x
955938
956- if x .dim () == 4 :
957- # x shape: [b, (c * patch_size * patch_size), h, w]
958- batch_size , c_patches , height , width = x .shape
959- channels = c_patches // (patch_size * patch_size )
960-
961- # Reshape to [b, c, patch_size, patch_size, h, w]
962- x = x .view (batch_size , channels , patch_size , patch_size , height , width )
963-
964- # Rearrange to [b, c, h * patch_size, w * patch_size]
965- x = x .permute (0 , 1 , 4 , 2 , 5 , 3 ).contiguous ()
966- x = x .view (batch_size , channels , height * patch_size , width * patch_size )
967-
968- elif x .dim () == 5 :
969- # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
970- batch_size , c_patches , frames , height , width = x .shape
971- channels = c_patches // (patch_size * patch_size )
939+ if x .dim () != 5 :
940+ raise ValueError (f"Invalid input shape: { x .shape } " )
941+ # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
942+ batch_size , c_patches , frames , height , width = x .shape
943+ channels = c_patches // (patch_size * patch_size )
972944
973- # Reshape to [b, c, patch_size, patch_size, f, h, w]
974- x = x .view (batch_size , channels , patch_size , patch_size , frames , height , width )
945+ # Reshape to [b, c, patch_size, patch_size, f, h, w]
946+ x = x .view (batch_size , channels , patch_size , patch_size , frames , height , width )
975947
976- # Rearrange to [b, c, f, h * patch_size, w * patch_size]
977- x = x .permute (0 , 1 , 4 , 5 , 2 , 6 , 3 ).contiguous ()
978- x = x .view (batch_size , channels , frames , height * patch_size , width * patch_size )
948+ # Rearrange to [b, c, f, h * patch_size, w * patch_size]
949+ x = x .permute (0 , 1 , 4 , 5 , 3 , 6 , 2 ).contiguous ()
950+ x = x .view (batch_size , channels , frames , height * patch_size , width * patch_size )
979951
980952 return x
981953
@@ -1044,7 +1016,6 @@ def __init__(
10441016 patch_size : Optional [int ] = None ,
10451017 scale_factor_temporal : Optional [int ] = 4 ,
10461018 scale_factor_spatial : Optional [int ] = 8 ,
1047- clip_output : bool = True ,
10481019 ) -> None :
10491020 super ().__init__ ()
10501021
@@ -1244,10 +1215,11 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True):
12441215 out_ = self .decoder (x [:, :, i : i + 1 , :, :], feat_cache = self ._feat_map , feat_idx = self ._conv_idx )
12451216 out = torch .cat ([out , out_ ], 2 )
12461217
1247- if self .config .clip_output :
1248- out = torch .clamp (out , min = - 1.0 , max = 1.0 )
12491218 if self .config .patch_size is not None :
12501219 out = unpatchify (out , patch_size = self .config .patch_size )
1220+
1221+ out = torch .clamp (out , min = - 1.0 , max = 1.0 )
1222+
12511223 self .clear_cache ()
12521224 if not return_dict :
12531225 return (out ,)
0 commit comments