@@ -700,13 +700,13 @@ def __init__(
         self.use_framewise_encoding = True
         self.use_framewise_decoding = True

-        # only relevant if vae tiling is enabled
-        self.tile_sample_min_tsize = sample_tsize
-        self.tile_latent_min_tsize = sample_tsize // temporal_compression_ratio

         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 256
         self.tile_sample_min_width = 256
+
+        # The minimal tile temporal batch size for temporal tiling to be used
+        self.tile_sample_min_tsize = 64

         # The minimal distance between two spatial tiles
         self.tile_sample_stride_height = 192
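
The hunk above replaces the constructor-time latent tile size (previously computed from a sample_tsize argument) with a fixed sample-space minimum; the latent-space value is now derived wherever it is needed. A minimal sketch of that derivation, where tile_sample_min_tsize = 64 comes from the diff and the compression ratio of 4 is an assumption:

tile_sample_min_tsize = 64          # minimal temporal tile size in sample (pixel) space, from the diff
temporal_compression_ratio = 4      # sample frames per latent frame (assumed value)

# Derived on demand instead of being stored on the module:
tile_latent_min_tsize = tile_sample_min_tsize // temporal_compression_ratio  # -> 16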
@@ -812,8 +812,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
         batch_size, num_channels, num_frames, height, width = z.shape
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_tsize // self.temporal_compression_ratio

-        if self.use_framewise_decoding and num_frames > self.tile_latent_min_tsize:
+        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
             return self.temporal_tiled_decode(z, return_dict=return_dict)

         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
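
Read as a standalone predicate, the new dispatch condition looks like the sketch below; the function name and signature are hypothetical, for illustration only:

def needs_temporal_tiled_decode(
    num_frames: int, tile_sample_min_tsize: int, temporal_compression_ratio: int
) -> bool:
    # Temporal tiled decoding kicks in once the latent frame count exceeds
    # the minimal latent tile size derived from the sample-space setting.
    return num_frames > tile_sample_min_tsize // temporal_compression_ratio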
@@ -987,9 +988,10 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod

     def temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
         B, C, T, H, W = x.shape
+        tile_latent_min_tsize = self.tile_sample_min_tsize // self.temporal_compression_ratio
         overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
-        blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
-        t_limit = self.tile_latent_min_tsize - blend_extent
+        blend_extent = int(tile_latent_min_tsize * self.tile_overlap_factor)
+        t_limit = tile_latent_min_tsize - blend_extent

         # Split the video into tiles and encode them separately.
         row = []
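
A worked example of the encoder's tile geometry after this change, assuming tile_overlap_factor = 0.25 and temporal_compression_ratio = 4 (both assumptions; only tile_sample_min_tsize = 64 is set in the diff):

tile_sample_min_tsize = 64         # from the diff
temporal_compression_ratio = 4     # assumed
tile_overlap_factor = 0.25         # assumed

tile_latent_min_tsize = tile_sample_min_tsize // temporal_compression_ratio  # 16
overlap_size = int(tile_sample_min_tsize * (1 - tile_overlap_factor))        # 48: stride between tile starts in sample space
blend_extent = int(tile_latent_min_tsize * tile_overlap_factor)              # 4: latent frames cross-faded with the previous tile
t_limit = tile_latent_min_tsize - blend_extent                               # 12: latent frames kept per tile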
@@ -1020,13 +1022,14 @@ def temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Un
         # Split z into overlapping tiles and decode them separately.

         B, C, T, H, W = z.shape
-        overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
+        tile_latent_min_tsize = self.tile_sample_min_tsize // self.temporal_compression_ratio
+        overlap_size = int(tile_latent_min_tsize * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
         t_limit = self.tile_sample_min_tsize - blend_extent

         row = []
         for i in range(0, T, overlap_size):
-            tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
+            tile = z[:, :, i : i + tile_latent_min_tsize + 1, :, :]
             if self.use_tiling and (
                 tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
             ):
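
The decode path mirrors the encoder: each iteration slices tile_latent_min_tsize + 1 latent frames, and neighboring decoded tiles are cross-faded over blend_extent frames before t_limit frames are kept. Below is a minimal sketch of such a temporal blend, a common pattern in tiled video VAEs and not necessarily the exact helper used in this file; tensors are assumed to be 5D (B, C, T, H, W):

import torch

def blend_t(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Linearly cross-fade the last `blend_extent` frames of the previous tile
    # `a` into the first `blend_extent` frames of the current tile `b`,
    # along the temporal axis (dim 2).
    blend_extent = min(a.shape[2], b.shape[2], blend_extent)
    for t in range(blend_extent):
        w = t / blend_extent
        b[:, :, t, :, :] = a[:, :, -blend_extent + t, :, :] * (1 - w) + b[:, :, t, :, :] * w
    return b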