@@ -1277,6 +1277,9 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
12771277            `torch.Tensor`: 
12781278                The latent representation of the encoded videos. 
12791279        """ 
1280+         if  self .config .patch_size  is  not None :
1281+             x  =  patchify (x , patch_size = self .config .patch_size )
1282+             
12801283        _ , _ , num_frames , height , width  =  x .shape 
12811284        latent_height  =  height  //  self .spatial_compression_ratio 
12821285        latent_width  =  width  //  self .spatial_compression_ratio 
@@ -1311,7 +1314,7 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
13111314                            j  : j  +  self .tile_sample_min_width ,
13121315                        ]
13131316                    tile  =  self .encoder (tile , feat_cache = self ._enc_feat_map , feat_idx = self ._enc_conv_idx )
1314-                     tile  =  self .quant_conv (tile )
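1317+                     # NOTE: quant_conv is not applied per tile; it is applied once to the concatenated, cropped latent just before this method returns.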
13151318                    time .append (tile )
13161319                row .append (torch .cat (time , dim = 2 ))
13171320            rows .append (row )
@@ -1331,6 +1334,7 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
13311334            result_rows .append (torch .cat (result_row , dim = - 1 ))
13321335
13331336        enc  =  torch .cat (result_rows , dim = 3 )[:, :, :, :latent_height , :latent_width ]
1337+         enc  =  self .quant_conv (enc )
13341338        return  enc 
13351339
13361340    def  tiled_decode (self , z : torch .Tensor , return_dict : bool  =  True ) ->  Union [DecoderOutput , torch .Tensor ]:
@@ -1347,6 +1351,8 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
13471351                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is 
13481352                returned. 
13491353        """ 
1354+         z  =  self .post_quant_conv (z )
1355+ 
13501356        _ , _ , num_frames , height , width  =  z .shape 
13511357        sample_height  =  height  *  self .spatial_compression_ratio 
13521358        sample_width  =  width  *  self .spatial_compression_ratio 
@@ -1370,8 +1376,11 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
13701376                for  k  in  range (num_frames ):
13711377                    self ._conv_idx  =  [0 ]
13721378                    tile  =  z [:, :, k  : k  +  1 , i  : i  +  tile_latent_min_height , j  : j  +  tile_latent_min_width ]
1373-                     tile  =  self .post_quant_conv (tile )
1374-                     decoded  =  self .decoder (tile , feat_cache = self ._feat_map , feat_idx = self ._conv_idx )
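1379+                     # NOTE: post_quant_conv is not applied per tile; it is applied once to the full latent `z` at the top of this method, before it is sliced into tiles.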
1380+                     if  k  ==  0 :
1381+                         decoded  =  self .decoder (tile , feat_cache = self ._feat_map , feat_idx = self ._conv_idx ,first_chunk = True )
1382+                     else :
1383+                         decoded  =  self .decoder (tile , feat_cache = self ._feat_map , feat_idx = self ._conv_idx )
13751384                    time .append (decoded )
13761385                row .append (torch .cat (time , dim = 2 ))
13771386            rows .append (row )
@@ -1392,6 +1401,10 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
13921401
13931402        dec  =  torch .cat (result_rows , dim = 3 )[:, :, :, :sample_height , :sample_width ]
13941403
1404+         if  self .config .patch_size  is  not None :
1405+             dec  =  unpatchify (dec , patch_size = self .config .patch_size )
1406+         dec  =  torch .clamp (dec , min = - 1.0 , max = 1.0 )
1407+ 
13951408        if  not  return_dict :
13961409            return  (dec ,)
13971410        return  DecoderOutput (sample = dec )
0 commit comments