@@ -730,6 +730,76 @@ def __init__(
             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
         )

+        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
+        self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+
+        # The stride between two consecutive spatial tiles
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[int] = None,
+        tile_sample_stride_width: Optional[int] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger videos.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
+                artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
     def clear_cache(self):
         def _count_conv3d(model):
             count = 0
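The toggles added above are meant to be flipped by the caller before encoding or decoding. A minimal usage sketch, assuming the surrounding class is the Wan video VAE exposed as `AutoencoderKLWan` in diffusers; the checkpoint id is hypothetical and only for illustration:

    from diffusers import AutoencoderKLWan

    # Hypothetical checkpoint id, shown only to illustrate the call pattern.
    vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae")
    vae.enable_slicing()  # decode one video latent at a time across the batch dimension
    vae.enable_tiling(tile_sample_min_height=256, tile_sample_stride_height=192)  # 64-pixel vertical overlap

Passing `None` (or omitting an argument) keeps the default set in `__init__`, since `enable_tiling` falls back via `tile_sample_min_height or self.tile_sample_min_height`.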
@@ -746,11 +816,14 @@ def _count_conv3d(model):
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num

-    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+    def _encode(self, x: torch.Tensor):
+        _, _, num_frame, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
         self.clear_cache()
-        ## cache
-        t = x.shape[2]
-        iter_ = 1 + (t - 1) // 4
+        iter_ = 1 + (num_frame - 1) // 4
         for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
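The `iter_ = 1 + (num_frame - 1) // 4` loop mirrors the encoder's 4x temporal downsampling: the first pass sees only frame 0, and each later pass sees a block of 4 frames. A quick sanity check of that arithmetic (plain Python, no dependencies):

    num_frame = 17                      # a typical Wan clip length of the form 4k + 1
    iter_ = 1 + (num_frame - 1) // 4    # 1 + 16 // 4 = 5 passes
    chunks = [(0, 1)] + [(1 + 4 * (i - 1), 1 + 4 * i) for i in range(1, iter_)]
    assert chunks == [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]

The same frame indexing reappears in `tiled_encode` below, where `frame_range` plays the role of `iter_`.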
@@ -764,8 +837,6 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
                 out = torch.cat([out, out_], 2)

         enc = self.quant_conv(out)
-        mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
-        enc = torch.cat([mu, logvar], dim=1)
         self.clear_cache()
         return enc

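The two removed lines split `enc` into `mu`/`logvar` along the channel dimension and immediately re-concatenated them, which is a no-op, so they can be dropped: `DiagonalGaussianDistribution` performs that split itself on its input. A simplified sketch of what the downstream class does with `enc` (modeled on diffusers' `DiagonalGaussianDistribution`; details such as logvar clamping and generator handling are omitted):

    import torch

    class DiagonalGaussianDistributionSketch:
        def __init__(self, parameters: torch.Tensor):
            # The channel dimension stacks [mean, logvar], so chunking recovers both halves.
            self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)

        def sample(self) -> torch.Tensor:
            std = torch.exp(0.5 * self.logvar)
            return self.mean + std * torch.randn_like(self.mean)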
@@ -785,18 +856,28 @@ def encode(
                 The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self._encode(x)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)

-    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        self.clear_cache()
+    def _decode(self, z: torch.Tensor, return_dict: bool = True):
+        _, _, num_frame, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)

-        iter_ = z.shape[2]
+        self.clear_cache()
         x = self.post_quant_conv(z)
-        for i in range(iter_):
+        for i in range(num_frame):
             self._conv_idx = [0]
             if i == 0:
                 out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
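Note that `_decode` converts the tiling threshold from pixel space to latent space by dividing by `spatial_compression_ratio`, whereas `_encode` compares against the pixel-space tile size directly. A rough worked example, assuming the common Wan configuration where the ratio evaluates to 8 (the actual value is whatever `2 ** len(self.temperal_downsample)` gives for the loaded config):

    spatial_compression_ratio = 8                  # assumed default config
    tile_sample_min_height = 256                   # pixel-space tile size
    tile_latent_min_height = tile_sample_min_height // spatial_compression_ratio  # 32 latent rows
    # A latent taller or wider than 32 therefore routes through tiled_decode().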
@@ -826,12 +907,161 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
-        decoded = self._decode(z).sample
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
         if not return_dict:
             return (decoded,)
-
         return DecoderOutput(sample=decoded)

+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of videos using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                self.clear_cache()
+                time = []
+                frame_range = 1 + (num_frames - 1) // 4
+                for k in range(frame_range):
+                    self._enc_conv_idx = [0]
+                    if k == 0:
+                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    else:
+                        tile = x[
+                            :,
+                            :,
+                            1 + 4 * (k - 1) : 1 + 4 * k,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                    tile = self.quant_conv(tile)
+                    time.append(tile)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+        self.clear_cache()
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of videos using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                self.clear_cache()
+                time = []
+                for k in range(num_frames):
+                    self._conv_idx = [0]
+                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    tile = self.post_quant_conv(tile)
+                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    time.append(decoded)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+        self.clear_cache()
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
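For reference, the overlap that `blend_v`/`blend_h` smooth over is `tile_sample_min - tile_sample_stride` (256 - 192 = 64 pixels by default, i.e. 8 latent rows or columns after 8x spatial compression, matching `blend_height`/`blend_width` in `tiled_encode`). The blend itself is a linear cross-fade along the overlap; a standalone sketch of the same formula on dummy tensors (shapes are illustrative only):

    import torch

    def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        # Fade from tile `a` (above) into tile `b` (below) over the first `blend_extent` rows of `b`.
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        for y in range(blend_extent):
            w = y / blend_extent
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - w) + b[:, :, :, y, :] * w
        return b

    a = torch.zeros(1, 4, 1, 32, 32)   # stands in for the bottom rows of the upper tile
    b = torch.ones(1, 4, 1, 32, 32)    # stands in for the top rows of the lower tile
    out = blend_v(a, b, blend_extent=8)
    print(out[0, 0, 0, :8, 0])         # ramps from 0.0 toward 1.0 across the 8 overlapping rows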