@@ -730,6 +730,76 @@ def __init__(
             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
         )

+        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
+        self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+
+        # The minimal distance between two spatial tiles
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
+                artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
     def clear_cache(self):
         def _count_conv3d(model):
             count = 0
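With the defaults introduced above, tiling only activates once an input exceeds 256 pixels on a side; neighbouring 256-pixel tiles are then placed 192 pixels apart, so each seam leaves a band that gets blended away. A minimal restatement of that arithmetic (the numbers are just the defaults from this hunk):

tile_sample_min_height = 256     # tile edge, and the size threshold for tiling to kick in
tile_sample_stride_height = 192  # distance between the origins of adjacent tiles
overlap = tile_sample_min_height - tile_sample_stride_height  # 64 pixels cross-faded per seam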
@@ -746,11 +816,14 @@ def _count_conv3d(model):
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num

-    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+    def _encode(self, x: torch.Tensor):
+        _, _, num_frame, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
         self.clear_cache()
-        ## cache
-        t = x.shape[2]
-        iter_ = 1 + (t - 1) // 4
+        iter_ = 1 + (num_frame - 1) // 4
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:
@@ -764,8 +837,6 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
                 out = torch.cat([out, out_], 2)

         enc = self.quant_conv(out)
-        mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
-        enc = torch.cat([mu, logvar], dim=1)
         self.clear_cache()
         return enc

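Two things worth noting about this hunk: the frame loop encodes causally, taking the first frame alone and then chunks of four frames, and the removed `mu`/`logvar` split was a no-op, since concatenating the two halves back along `dim=1` reproduces `enc` unchanged. A quick check of the chunk count for a 17-frame clip (an illustrative number, not from the diff):

num_frame = 17
iter_ = 1 + (num_frame - 1) // 4  # 5 passes over frames [0:1], [1:5], [5:9], [9:13], [13:17]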
@@ -785,18 +856,28 @@ def encode(
                 The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self._encode(x)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)

-    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        self.clear_cache()
+    def _decode(self, z: torch.Tensor, return_dict: bool = True):
+        _, _, num_frame, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)

-        iter_ = z.shape[2]
+        self.clear_cache()
         x = self.post_quant_conv(z)
-        for i in range(iter_):
+        for i in range(num_frame):
             self._conv_idx = [0]
             if i == 0:
                 out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
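Taken together with the slicing changes in the next hunk, these hooks are opt-in. A minimal usage sketch, assuming this file defines the `AutoencoderKLWan` class exported by `diffusers` (the checkpoint id below is a placeholder, not a real repository):

from diffusers import AutoencoderKLWan

# "<wan-checkpoint>" is a placeholder; substitute the actual Wan VAE weights.
vae = AutoencoderKLWan.from_pretrained("<wan-checkpoint>", subfolder="vae")
vae.enable_slicing()  # encode/decode one video of the batch per forward pass
vae.enable_tiling()   # split frames larger than 256 px into 256 px tiles with a 192 px stride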
@@ -826,12 +907,161 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
-        decoded = self._decode(z).sample
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
         if not return_dict:
             return (decoded,)
-
         return DecoderOutput(sample=decoded)

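+    # The two helpers below linearly cross-fade the overlapping band between neighbouring
+    # tiles: position y in the band takes weight (1 - y / blend_extent) from the previous
+    # tile and weight y / blend_extent from the current one.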
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of videos using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                self.clear_cache()
+                time = []
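+                # Encode this spatial tile causally over time: the first frame alone, then
+                # chunks of four frames, mirroring the frame loop in _encode above.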
+                frame_range = 1 + (num_frames - 1) // 4
+                for k in range(frame_range):
+                    self._enc_conv_idx = [0]
+                    if k == 0:
+                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    else:
+                        tile = x[
+                            :,
+                            :,
+                            1 + 4 * (k - 1) : 1 + 4 * k,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                    tile = self.quant_conv(tile)
+                    time.append(tile)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+        self.clear_cache()
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of videos using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                self.clear_cache()
+                time = []
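+                # Decode this tile frame by frame; the decoder's feature cache carries the
+                # causal temporal context between consecutive frames of the same tile.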
+                for k in range(num_frames):
+                    self._conv_idx = [0]
+                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    tile = self.post_quant_conv(tile)
+                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    time.append(decoded)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+        self.clear_cache()
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
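For the tiled paths added above, the tile geometry is carried into latent space by integer division with `spatial_compression_ratio`. A minimal sketch with the default tile settings, assuming a compression ratio of 8 (i.e. `len(temperal_downsample) == 3`; other configs change the numbers):

spatial_compression_ratio = 8  # assumed value of 2 ** len(temperal_downsample)
tile_latent_min_height = 256 // spatial_compression_ratio     # 32 latent rows per tile
tile_latent_stride_height = 192 // spatial_compression_ratio  # 24 latent rows between tile origins
blend_height = tile_latent_min_height - tile_latent_stride_height  # 8 rows blended in tiled_encode
# tiled_decode blends in sample space instead: 256 - 192 = 64 pixel rows per seam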