@@ -730,6 +730,76 @@ def __init__(
             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
         )
 
+        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
+        self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+
+        # The minimal distance between two spatial tiles
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[int] = None,
+        tile_sample_stride_width: Optional[int] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+        processing larger videos.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between two consecutive vertical tiles. The overlap (tile height minus stride) is blended
+                so that no tiling artifacts are produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. The overlap (tile width minus stride) is blended
+                so that no tiling artifacts are produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
     def clear_cache(self):
         def _count_conv3d(model):
             count = 0
@@ -746,11 +816,14 @@ def _count_conv3d(model):
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num
 
-    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+    def _encode(self, x: torch.Tensor):
+        _, _, num_frame, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
         self.clear_cache()
-        ## cache
-        t = x.shape[2]
-        iter_ = 1 + (t - 1) // 4
+        iter_ = 1 + (num_frame - 1) // 4
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:
@@ -764,9 +837,6 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
                 out = torch.cat([out, out_], 2)
 
         enc = self.quant_conv(out)
-        mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
-        enc = torch.cat([mu, logvar], dim=1)
-        self.clear_cache()
         return enc
 
     @apply_forward_hook
@@ -785,18 +855,28 @@ def encode(
                 The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self._encode(x)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        self.clear_cache()
+    def _decode(self, z: torch.Tensor, return_dict: bool = True):
+        _, _, num_frame, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)
 
-        iter_ = z.shape[2]
+        self.clear_cache()
         x = self.post_quant_conv(z)
-        for i in range(iter_):
+        for i in range(num_frame):
             self._conv_idx = [0]
             if i == 0:
                 out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
@@ -826,12 +906,159 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
-        decoded = self._decode(z).sample
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
         if not return_dict:
             return (decoded,)
-
         return DecoderOutput(sample=decoded)
 
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of videos using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                self.clear_cache()
+                time = []
+                frame_range = 1 + (num_frames - 1) // 4
+                for k in range(frame_range):
+                    self._enc_conv_idx = [0]
+                    if k == 0:
+                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    else:
+                        tile = x[
+                            :,
+                            :,
+                            1 + 4 * (k - 1) : 1 + 4 * k,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                    tile = self.quant_conv(tile)
+                    time.append(tile)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of video latents using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                self.clear_cache()
+                time = []
+                for k in range(num_frames):
+                    self._conv_idx = [0]
+                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    tile = self.post_quant_conv(tile)
+                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    time.append(decoded)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
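
For reference, the new memory-saving switches are plain instance methods, so they are toggled on the VAE instance before calling `encode`/`decode`. Below is a minimal usage sketch, not taken from this commit: the class name `AutoencoderKLWan`, the checkpoint path, and the latent shape are assumptions made only for illustration. With the defaults above, 256-pixel tiles and a 192-pixel stride leave a 64-pixel overlap that `blend_v`/`blend_h` fade linearly to hide seams between tiles.

```python
import torch
from diffusers import AutoencoderKLWan  # assumed export name; adjust to the actual class

# Hypothetical checkpoint path, shown only for illustration.
vae = AutoencoderKLWan.from_pretrained("your-org/wan-vae", torch_dtype=torch.float32)

# Decode one video latent at a time instead of the whole batch at once.
vae.enable_slicing()

# Decode each latent in overlapping spatial tiles; smaller tiles lower peak memory,
# while a smaller stride leaves more overlap for blending.
vae.enable_tiling(
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_stride_height=192,
    tile_sample_stride_width=192,
)

# Example latent shape (batch, channels, frames, latent_height, latent_width) — assumed values.
latents = torch.randn(1, vae.config.z_dim, 13, 60, 104)
with torch.no_grad():
    video = vae.decode(latents).sample  # (batch, 3, num_frames, height, width)
```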