@@ -677,42 +677,7 @@ def __init__(
         attn_scales: List[float] = [],
         temperal_downsample: List[bool] = [False, True, True],
         dropout: float = 0.0,
-        latents_mean: List[float] = [
-            -0.7571,
-            -0.7089,
-            -0.9113,
-            0.1075,
-            -0.1745,
-            0.9653,
-            -0.1517,
-            1.5508,
-            0.4134,
-            -0.0715,
-            0.5517,
-            -0.3632,
-            -0.1922,
-            -0.9497,
-            0.2503,
-            -0.2921,
-        ],
-        latents_std: List[float] = [
-            2.8184,
-            1.4541,
-            2.3275,
-            2.6558,
-            1.2196,
-            1.7708,
-            2.6052,
-            2.0743,
-            3.2687,
-            2.1526,
-            2.8652,
-            1.5579,
-            1.6382,
-            1.1253,
-            2.8251,
-            1.9160,
-        ],
+        spatial_compression_ratio: int = 8,
     ) -> None:
         super().__init__()
 
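
The new `spatial_compression_ratio` argument fixes how many sample pixels map to one latent along each spatial axis; every latent-space tile size below is derived from it by integer division. A quick sketch of that arithmetic, with hypothetical video dimensions:

spatial_compression_ratio = 8
height, width = 480, 832                             # hypothetical sample-space frame size
latent_height = height // spatial_compression_ratio  # 60
latent_width = width // spatial_compression_ratio    # 104
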
@@ -730,6 +695,58 @@ def __init__(
             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
         )
 
+        self.spatial_compression_ratio = spatial_compression_ratio
+
+        # Decoding spatially large video latents requires a lot of memory. Splitting the latent frames into
+        # smaller spatial tiles, decoding each tile in its own forward pass, and then blending the decoded
+        # tiles back together lowers the peak memory requirement.
+        self.use_tiling = False
+
+        # The minimal sample height/width at which spatial tiling is applied; also the spatial size of each tile
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+
+        # The stride between the starts of two adjacent spatial tiles
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
+
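
With these defaults, adjacent tiles overlap by 256 - 192 = 64 pixels in sample space (8 latents at a spatial compression ratio of 8); that overlap is exactly what the `blend_*` helpers further down cross-fade. A quick sanity check, assuming the defaults above:

tile_sample_min, tile_sample_stride = 256, 192     # defaults set in __init__
overlap_px = tile_sample_min - tile_sample_stride  # 64 sample pixels
overlap_latents = overlap_px // 8                  # 8 latents at ratio 8
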
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[int] = None,
+        tile_sample_stride_width: Optional[int] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+        allow processing larger videos.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between the starts of two consecutive vertical tiles. Keeping it smaller than
+                `tile_sample_min_height` leaves an overlap that is blended away, so that no tiling artifacts are
+                produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between the starts of two consecutive horizontal tiles. Keeping it smaller than
+                `tile_sample_min_width` leaves an overlap that is blended away, so that no tiling artifacts are
+                produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
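
A minimal usage sketch, assuming an already-constructed `vae` instance of this class and a `latents` tensor (both hypothetical, not part of this diff):

# Hypothetical usage: bound peak decode memory for large videos
vae.enable_tiling(
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_stride_height=192,
    tile_sample_stride_width=192,
)
video = vae.decode(latents).sample  # tiles are decoded and blended internally
vae.disable_tiling()                # back to single-pass decoding
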
     def clear_cache(self):
         def _count_conv3d(model):
             count = 0
@@ -785,7 +802,11 @@ def encode(
                 The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self._encode(x)
+        _, _, _, height, width = x.shape
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            h = self.tiled_encode(x)
+        else:
+            h = self._encode(x)
         posterior = DiagonalGaussianDistribution(h)
         if not return_dict:
             return (posterior,)
@@ -826,12 +847,170 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
-        decoded = self._decode(z).sample
+        _, _, _, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            decoded = self.tiled_decode(z).sample
+        else:
+            decoded = self._decode(z).sample
         if not return_dict:
             return (decoded,)
 
         return DecoderOutput(sample=decoded)
 
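
Note that `decode` applies the tiling threshold in latent space: the sample-space minimum tile size is divided by the compression ratio before being compared against the latent's spatial extent. A sketch with the defaults and a hypothetical latent size:

use_tiling = True
latent_height, latent_width = 60, 104  # hypothetical latent frame size
tile_latent_min = 256 // 8             # 32
tiled = use_tiling and (latent_width > tile_latent_min or latent_height > tile_latent_min)  # True
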
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b
+
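
The three helpers implement the same linear cross-fade along different axes (height, width, and time): the first `blend_extent` slices of the incoming tile `b` ramp from the previous tile's trailing values to `b`'s own. A self-contained sketch of `blend_v`'s arithmetic on hypothetical tensors:

import torch

a = torch.ones(1, 1, 1, 4, 1)   # trailing edge of the previous tile
b = torch.zeros(1, 1, 1, 4, 1)  # leading edge of the current tile
blend_extent = 4
for y in range(blend_extent):
    b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
print(b.flatten().tolist())  # [1.0, 0.75, 0.5, 0.25]
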
+    def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        r"""Encode a batch of videos using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                self.clear_cache()
+                time = []
+                # The first frame is encoded alone; the remaining frames are encoded in chunks of four.
+                frame_range = 1 + (num_frames - 1) // 4
+                for k in range(frame_range):
+                    self._enc_conv_idx = [0]
+                    if k == 0:
+                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    else:
+                        tile = x[
+                            :,
+                            :,
+                            1 + 4 * (k - 1) : 1 + 4 * k,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                    tile = self.quant_conv(tile)
+                    time.append(tile)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # Blend the tile above and the tile to the left into the current tile,
+                # then crop the current tile to the stride and add it to the result row.
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
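
The temporal loop above mirrors the VAE's 4x temporal downsampling (the two `True` entries in `temperal_downsample`): the first video frame is encoded on its own and the remaining frames in chunks of four, which is where `frame_range = 1 + (num_frames - 1) // 4` comes from. A sketch with a hypothetical frame count:

num_frames = 81  # hypothetical clip length
frame_range = 1 + (num_frames - 1) // 4
print(frame_range)  # 21 encoder passes per spatial tile: frame 0, then chunks [1:5], ..., [77:81]
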
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of video latents using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                self.clear_cache()
+                time = []
+                # Latent frames are decoded one at a time, reusing the feature cache across frames.
+                for k in range(num_frames):
+                    self._conv_idx = [0]
+                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    tile = self.post_quant_conv(tile)
+                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    time.append(decoded)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # Blend the tile above and the tile to the left into the current tile,
+                # then crop the current tile to the stride and add it to the result row.
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
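
The memory/compute trade is easy to quantify: tiled decoding runs the decoder once per spatial tile per latent frame instead of once per frame on the full latent, with each pass bounded by the latent tile size. A rough count with the defaults and a hypothetical 480x832 video:

import math

latent_height, latent_width = 60, 104  # hypothetical 480x832 video at ratio 8
stride = 192 // 8                      # 24 latents between tile starts
n_tiles = math.ceil(latent_height / stride) * math.ceil(latent_width / stride)
print(n_tiles)  # 3 * 5 = 15 decoder passes per latent frame, each on at most a 32x32 latent tile
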
     def forward(
         self,
         sample: torch.Tensor,