@@ -730,6 +730,77 @@ def __init__(
             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
         )

+        self.temporal_compression_ratio = 2 ** sum(self.temperal_downsample)
+        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
+        self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+
+        # The minimal distance between two spatial tiles
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
+
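The two ratios above follow directly from the downsample flags: each enabled temporal downsample halves the frame count, while every stage halves the spatial resolution. A quick arithmetic check, assuming the common Wan configuration temperal_downsample = [False, True, True] (an assumption; the default is not shown in this diff):

temperal_downsample = [False, True, True]  # assumed config value, for illustration only
temporal_compression_ratio = 2 ** sum(temperal_downsample)  # 2 ** 2 = 4
spatial_compression_ratio = 2 ** len(temperal_downsample)   # 2 ** 3 = 8
assert (temporal_compression_ratio, spatial_compression_ratio) == (4, 8)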
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and
+        allowing the processing of larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between two consecutive vertical tiles. The overlap (`tile_sample_min_height -
+                tile_sample_stride_height`) is blended to ensure that there are no tiling artifacts across the
+                height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. The overlap is blended to ensure that there
+                are no tiling artifacts across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
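A minimal usage sketch for the four toggles above; the import path and checkpoint name are assumptions for illustration and are not part of this change:

import torch
from diffusers import AutoencoderKLWan  # assumed import path

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)  # assumed checkpoint
vae.enable_slicing()  # encode/decode one batch element at a time to save memory
vae.enable_tiling(tile_sample_min_height=256, tile_sample_min_width=256)  # spatial tiling kicks in above 256x256
# ... call vae.encode(...) / vae.decode(...) as usual; disable_slicing() / disable_tiling() restore one-shot behavior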
     def clear_cache(self):
         def _count_conv3d(model):
             count = 0
@@ -746,7 +817,7 @@ def _count_conv3d(model):
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num

-    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+    def vanilla_encode(self, x: torch.Tensor) -> torch.Tensor:
         self.clear_cache()
         ## cache
         t = x.shape[2]
@@ -769,6 +840,12 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
         self.clear_cache()
         return enc

+    def _encode(self, x: torch.Tensor):
+        _, _, _, height, width = x.shape
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+        return self.vanilla_encode(x)
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
@@ -785,13 +862,18 @@ def encode(
                 The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self._encode(x)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)

-    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+    def vanilla_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         self.clear_cache()

         iter_ = z.shape[2]
@@ -811,6 +893,15 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut

         return DecoderOutput(sample=out)

+    def _decode(self, z: torch.Tensor, return_dict: bool = True):
+        _, _, _, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)
+        return self.vanilla_decode(z, return_dict=return_dict)
+
     @apply_forward_hook
     def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         r"""
@@ -826,12 +917,167 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
-        decoded = self._decode(z).sample
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
         if not return_dict:
             return (decoded,)
-
         return DecoderOutput(sample=decoded)

+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b
+
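All three blend helpers perform the same linear crossfade over the overlapping region, just along different axes (height, width, time). A tiny self-contained check of the idea, re-implementing blend_h on toy 5D tensors so the ramp is visible:

import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # The left edge of `b` fades in while the right edge of `a` fades out over `blend_extent` columns.
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
    return b

a = torch.zeros(1, 1, 1, 2, 4)  # left tile (all zeros)
b = torch.ones(1, 1, 1, 2, 4)   # right tile (all ones)
out = blend_h(a, b, blend_extent=4)
print(out[0, 0, 0, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500]) -> smooth ramp instead of a hard seam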
+    def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        r"""Encode a batch of videos using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                self.clear_cache()
+                time = []
+                frame_range = 1 + (num_frames - 1) // 4
+                for k in range(frame_range):
+                    self._enc_conv_idx = [0]
+                    if k == 0:
+                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    else:
+                        tile = x[
+                            :,
+                            :,
+                            1 + 4 * (k - 1) : 1 + 4 * k,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                    tile = self.quant_conv(tile)
+                    time.append(tile)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile into the current tile,
+                # then add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
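The spatial tile grid and the temporal chunking are easiest to see with concrete numbers. A sketch using the defaults introduced above (256-pixel tiles, 192-pixel stride, 8x spatial / 4x temporal compression) and an illustrative 49-frame 480x832 input; the input size is an arbitrary example, not a requirement:

tile, stride, ratio = 256, 192, 8
num_frames, height, width = 49, 480, 832

# Tiles start every `stride` pixels, so neighbouring tiles overlap by tile - stride = 64 pixels.
starts_h = list(range(0, height, stride))  # [0, 192, 384]
starts_w = list(range(0, width, stride))   # [0, 192, 384, 576, 768]
blend_latent = (tile - stride) // ratio    # 8 latent rows/cols are crossfaded between neighbours

# The first frame is encoded on its own and every later chunk covers 4 frames,
# matching the 4x temporal compression.
latent_frames = 1 + (num_frames - 1) // 4  # 13
print(len(starts_h) * len(starts_w), "tiles,", latent_frames, "latent frames per tile")  # 15 tiles, 13 latent frames per tile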
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of videos using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                self.clear_cache()
+                time = []
+                for k in range(num_frames):
+                    self._conv_idx = [0]
+                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    tile = self.post_quant_conv(tile)
+                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    time.append(decoded)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile into the current tile,
+                # then add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
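Note that tiled_encode blends tiles in latent space (the overlap is measured after compression), while tiled_decode blends in pixel space (tiles are decoded first, then crossfaded). With the defaults above the two blend extents work out as follows; a small arithmetic sketch:

tile_min, tile_stride, ratio = 256, 192, 8

encode_blend = tile_min // ratio - tile_stride // ratio  # 32 - 24 = 8 latent rows/cols blended after encoding
decode_blend = tile_min - tile_stride                    # 64 pixel rows/cols blended after decoding
print(encode_blend, decode_blend)  # 8 64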
     def forward(
         self,
         sample: torch.Tensor,