@@ -568,11 +568,7 @@ def __init__(
         super().__init__()

         # 1. Input convolution
-        self.conv_in = EasyAnimateCausalConv3d(
-            in_channels,
-            block_out_channels[-1],
-            kernel_size=3,
-        )
+        self.conv_in = EasyAnimateCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3)

         # 2. Middle block
         self.mid_block = EasyAnimateMidBlock3d(
@@ -734,21 +730,36 @@ def __init__(
         self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
         self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)

-        # Assign mini-batch sizes for encoder and decoder
-        self.mini_batch_encoder = 4
-        self.mini_batch_decoder = 1
+        self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
+        self.temporal_compression_ratio = 2 ** (len(block_out_channels) - 2)

-        # Initialize tiling and slicing flags
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
         self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles, performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
         self.use_tiling = False

-        # Set parameters for tiling if used
-        tile_overlap_factor = 0.25
-        self.tile_sample_min_size = 384
-        self.tile_overlap_factor = tile_overlap_factor
-        self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(block_out_channels) - 1)))
-        # Assign the scaling factor for latent space
-        self.scaling_factor = scaling_factor
+        # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
+        # in fixed-size batches (see `self.num_latent_frames_batch_size`), the memory requirement can be lowered.
+        self.use_framewise_encoding = False
+        self.use_framewise_decoding = False
+
+        # Frame batch sizes for the encoder and decoder
+        self.num_sample_frames_batch_size = 4
+        self.num_latent_frames_batch_size = 1
+
+        # The minimum tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 512
+        self.tile_sample_min_width = 512
+        self.tile_sample_min_num_frames = 4
+
+        # The stride between the starts of two consecutive spatial tiles
+        self.tile_sample_stride_height = 448
+        self.tile_sample_stride_width = 448
+        self.tile_sample_stride_num_frames = 8

     def _clear_conv_cache(self):
         # Clear cache for convolutional layers if needed
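The two ratio attributes introduced above drive all of the tile-size bookkeeping later in the diff. A quick worked example of the arithmetic, assuming a typical four-level `block_out_channels` config such as `(128, 256, 512, 512)` (an illustrative assumption, not something this diff pins down):

```python
# Illustrative arithmetic only; `block_out_channels` is an assumed config.
block_out_channels = (128, 256, 512, 512)

spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)   # 8
temporal_compression_ratio = 2 ** (len(block_out_channels) - 2)  # 4

# With the default tile geometry from __init__ (512px tiles, 448px stride),
# consecutive tiles overlap by 64px in sample space and 8px in latent space.
tile_latent_min = 512 // spatial_compression_ratio     # 64
tile_latent_stride = 448 // spatial_compression_ratio  # 56
print(tile_latent_min - tile_latent_stride)            # 8 latent pixels blended
```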
@@ -760,13 +771,39 @@ def _clear_conv_cache(self):
760771
761772 def enable_tiling (
762773 self ,
774+ tile_sample_min_height : Optional [int ] = None ,
775+ tile_sample_min_width : Optional [int ] = None ,
776+ tile_sample_min_num_frames : Optional [int ] = None ,
777+ tile_sample_stride_height : Optional [float ] = None ,
778+ tile_sample_stride_width : Optional [float ] = None ,
779+ tile_sample_stride_num_frames : Optional [float ] = None ,
763780 ) -> None :
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
         compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
         processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
+                artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
         """
         self.use_tiling = True
+        self.use_framewise_decoding = True
+        self.use_framewise_encoding = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

     def disable_tiling(self) -> None:
         r"""
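For context, a usage sketch of the new knobs; `vae` stands in for a constructed instance of this autoencoder, and the values are illustrative rather than recommended defaults:

```python
def configure_tiling(vae) -> None:
    # `vae` is assumed to be an instance of this autoencoder class.
    # Smaller tiles lower peak memory at the cost of more forward passes;
    # each stride must stay below its min size so neighboring tiles overlap.
    vae.enable_tiling(
        tile_sample_min_height=384,
        tile_sample_min_width=384,
        tile_sample_stride_height=320,
        tile_sample_stride_width=320,
    )
    # disable_tiling() restores single-pass decoding.
```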
@@ -805,14 +842,13 @@ def _encode(
                 The latent representations of the encoded images. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            x = self.tiled_encode(x, return_dict=return_dict)
-            return x
+        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_width or x.shape[-2] > self.tile_sample_min_height):
+            return self.tiled_encode(x, return_dict=return_dict)

-        first_frames = self.encoder(x[:, :, 0:1, :, :])
+        first_frames = self.encoder(x[:, :, :1, :, :])
         h = [first_frames]
-        for i in range(1, x.shape[2], self.mini_batch_encoder):
-            next_frames = self.encoder(x[:, :, i : i + self.mini_batch_encoder, :, :])
+        for i in range(1, x.shape[2], self.num_sample_frames_batch_size):
+            next_frames = self.encoder(x[:, :, i : i + self.num_sample_frames_batch_size, :, :])
             h.append(next_frames)
         h = torch.cat(h, dim=2)
         moments = self.quant_conv(h)
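The encoder loop above groups frames as `[0:1]`, then `[1:1+b]`, `[1+b:1+2b]`, and so on: the first frame is encoded on its own (the causal convolution, cf. `EasyAnimateCausalConv3d` in the first hunk, handles it like a standalone image) and the rest in batches of `num_sample_frames_batch_size`. A small sketch of the resulting slice boundaries, plain index arithmetic with no model needed:

```python
num_frames = 17
b = 4  # num_sample_frames_batch_size

# First frame alone, then batches of b frames, mirroring the loop above.
groups = [(0, 1)] + [(i, min(i + b, num_frames)) for i in range(1, num_frames, b)]
print(groups)  # [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]
```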
@@ -849,18 +885,22 @@ def encode(
         return AutoencoderKLOutput(latent_dist=posterior)

     def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+        batch_size, num_channels, num_frames, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, return_dict=return_dict)

         z = self.post_quant_conv(z)

         # Process the first frame and save the result
-        first_frames = self.decoder(z[:, :, 0:1, :, :])
+        first_frames = self.decoder(z[:, :, :1, :, :])
         # Initialize the list to store the processed frames, starting with the first frame
         dec = [first_frames]
         # Process the remaining frames in batches of `num_latent_frames_batch_size`
-        for i in range(1, z.shape[2], self.mini_batch_decoder):
-            next_frames = self.decoder(z[:, :, i : i + self.mini_batch_decoder, :, :])
+        for i in range(1, z.shape[2], self.num_latent_frames_batch_size):
+            next_frames = self.decoder(z[:, :, i : i + self.num_latent_frames_batch_size, :, :])
             dec.append(next_frames)
         # Concatenate all processed frames along the frame dimension
         dec = torch.cat(dec, dim=2)
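The decode loop mirrors the encode side: with `num_latent_frames_batch_size = 1`, each latent frame past the first is decoded on its own and expanded temporally. Under the `temporal_compression_ratio = 4` assumption from the earlier example, the frame counts work out as in this sketch (assumed behavior, consistent with the 1 + 4k grouping used on the encode side, not asserted by this diff):

```python
def expected_sample_frames(num_latent_frames: int, temporal_ratio: int = 4) -> int:
    # First latent frame decodes to a single sample frame; every later latent
    # frame decodes to `temporal_ratio` sample frames.
    return 1 + (num_latent_frames - 1) * temporal_ratio

assert expected_sample_frames(5) == 17
```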
@@ -913,27 +953,35 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         return b

     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
-        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
-        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
-        row_limit = self.tile_latent_min_size - blend_extent
+        batch_size, num_channels, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width

         # Split the image into 512x512 tiles and encode them separately.
         rows = []
-        for i in range(0, x.shape[3], overlap_size):
+        for i in range(0, height, self.tile_sample_stride_height):
             row = []
-            for j in range(0, x.shape[4], overlap_size):
+            for j in range(0, width, self.tile_sample_stride_width):
                 tile = x[
                     :,
                     :,
                     :,
-                    i : i + self.tile_sample_min_size,
-                    j : j + self.tile_sample_min_size,
+                    i : i + self.tile_sample_min_height,
+                    j : j + self.tile_sample_min_width,
                 ]

                 first_frames = self.encoder(tile[:, :, 0:1, :, :])
                 tile_h = [first_frames]
-                for frame_index in range(1, tile.shape[2], self.mini_batch_encoder):
-                    next_frames = self.encoder(tile[:, :, frame_index : frame_index + self.mini_batch_encoder, :, :])
+                for k in range(1, num_frames, self.num_sample_frames_batch_size):
+                    next_frames = self.encoder(tile[:, :, k : k + self.num_sample_frames_batch_size, :, :])
                     tile_h.append(next_frames)
                 tile = torch.cat(tile_h, dim=2)
                 tile = self.quant_conv(tile)
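`blend_v` and `blend_h` (defined just above this hunk) feather the overlapping strip between two adjacent tiles. The diff only shows `blend_h`'s signature in the hunk header, so the body below is a sketch in the style of other diffusers VAE blend helpers, an assumption rather than code from this PR:

```python
import torch

def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Linearly ramp from tile `a`'s bottom rows into tile `b`'s top rows over
    # `blend_extent` pixels along the height axis (dim 3 of a NCFHW tensor).
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for y in range(blend_extent):
        w = y / blend_extent
        b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - w) + b[:, :, :, y, :] * w
    return b
```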
@@ -947,42 +995,50 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
                 # blend the above tile and the left tile
                 # to the current tile and add the current tile to the result row
                 if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                 if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_extent)
-                result_row.append(tile[:, :, :, :row_limit, :row_limit])
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
             result_rows.append(torch.cat(result_row, dim=4))

-        moments = torch.cat(result_rows, dim=3)
+        moments = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
         return moments

     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
-        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
-        row_limit = self.tile_sample_min_size - blend_extent
+        batch_size, num_channels, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

         # Split z into overlapping 64x64 tiles and decode them separately.
         # The tiles have an overlap to avoid seams between tiles.
         rows = []
-        for i in range(0, z.shape[3], overlap_size):
+        for i in range(0, height, tile_latent_stride_height):
             row = []
-            for j in range(0, z.shape[4], overlap_size):
+            for j in range(0, width, tile_latent_stride_width):
                 tile = z[
                     :,
                     :,
                     :,
-                    i : i + self.tile_latent_min_size,
-                    j : j + self.tile_latent_min_size,
+                    i : i + tile_latent_min_height,
+                    j : j + tile_latent_min_width,
                 ]
                 tile = self.post_quant_conv(tile)

                 # Process the first frame and save the result
-                first_frames = self.decoder(tile[:, :, 0:1, :, :])
+                first_frames = self.decoder(tile[:, :, :1, :, :])
                 # Initialize the list to store the processed frames, starting with the first frame
                 tile_dec = [first_frames]
                 # Process the remaining frames in batches of `num_latent_frames_batch_size`
-                for frame_index in range(1, tile.shape[2], self.mini_batch_decoder):
-                    next_frames = self.decoder(tile[:, :, frame_index : frame_index + self.mini_batch_decoder, :, :])
+                for k in range(1, num_frames, self.num_latent_frames_batch_size):
+                    next_frames = self.decoder(tile[:, :, k : k + self.num_latent_frames_batch_size, :, :])
                     tile_dec.append(next_frames)
                 # Concatenate all processed frames along the frame dimension
                 decoded = torch.cat(tile_dec, dim=2)
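To see how the crops compose in `tiled_decode`, take a 128x128 latent, i.e. a 1024x1024 sample at the assumed spatial compression ratio of 8, with the default 512/448 tile geometry:

```python
height = 128
tile_latent_min = 512 // 8     # 64
tile_latent_stride = 448 // 8  # 56

starts = list(range(0, height, tile_latent_stride))
print(starts)  # [0, 56, 112] -> three tile rows
# After decoding, each tile row is cropped to the sample-space stride (448px)
# before concatenation, and the final [:, :, :, :sample_height, :sample_width]
# crop below guarantees the output is exactly 1024px tall even when the tile
# grid overshoots the sample bounds.
```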
@@ -996,13 +1052,13 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
                 # blend the above tile and the left tile
                 # to the current tile and add the current tile to the result row
                 if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                 if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_extent)
-                result_row.append(tile[:, :, :, :row_limit, :row_limit])
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :self.tile_sample_stride_height, :self.tile_sample_stride_width])
             result_rows.append(torch.cat(result_row, dim=4))

-        dec = torch.cat(result_rows, dim=3)
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

         if not return_dict:
             return (dec,)
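Taken together, the memory-saving paths compose as in this hedged sketch, assuming a constructed `vae` instance of this class, an `enable_slicing` counterpart to the `use_slicing` flag set in `__init__`, and standard diffusers `DecoderOutput` semantics with the tensor on `.sample`:

```python
import torch

def decode_large_video(vae, latents: torch.Tensor) -> torch.Tensor:
    # latents: (batch, channels, latent_frames, latent_height, latent_width)
    vae.enable_slicing()  # assumed counterpart to `use_slicing` above
    vae.enable_tiling()   # spatial tiles with feathered seams, per tiled_decode
    with torch.no_grad():
        return vae.decode(latents).sample
```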