Commit f143b02

refactor
1 parent da53620 commit f143b02

File tree

1 file changed: +11 -8 lines changed


src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

Lines changed: 11 additions & 8 deletions
@@ -700,13 +700,13 @@ def __init__(
         self.use_framewise_encoding = True
         self.use_framewise_decoding = True
 
-        # only relevant if vae tiling is enabled
-        self.tile_sample_min_tsize = sample_tsize
-        self.tile_latent_min_tsize = sample_tsize // temporal_compression_ratio
 
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 256
         self.tile_sample_min_width = 256
+
+        # The minimal tile temporal batch size for temporal tiling to be used
+        self.tile_sample_min_tsize = 64
 
         # The minimal distance between two spatial tiles
         self.tile_sample_stride_height = 192
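With this change the precomputed latent-space tile size (the old self.tile_latent_min_tsize) is removed from __init__; the sample-space setting self.tile_sample_min_tsize = 64 becomes the single source of truth, and the latent value is derived wherever it is needed. A minimal sketch of that derivation, assuming the 4x temporal compression typical of the HunyuanVideo VAE (the ratio is not set by this commit):

# Illustrative only: how the latent tile size falls out of the new constant.
# The compression ratio of 4 is an assumed default, not fixed by this diff.
tile_sample_min_tsize = 64
temporal_compression_ratio = 4
tile_latent_min_tsize = tile_sample_min_tsize // temporal_compression_ratio
print(tile_latent_min_tsize)  # 16 latent frames per temporal tile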
@@ -812,8 +812,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
         batch_size, num_channels, num_frames, height, width = z.shape
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_tsize // self.temporal_compression_ratio
 
-        if self.use_framewise_decoding and num_frames > self.tile_latent_min_tsize:
+        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
             return self.temporal_tiled_decode(z, return_dict=return_dict)
 
         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
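The decode path therefore no longer reads the removed attribute; it recomputes the latent frame threshold on every call. A hedged example of what that means for a caller, with the same assumed 4x temporal compression:

# Hypothetical numbers for illustration only.
# A latent with 33 frames exceeds the derived threshold of 64 // 4 = 16,
# so _decode routes it to temporal_tiled_decode instead of decoding in one pass.
num_frames = 33
tile_latent_min_num_frames = 64 // 4
assert num_frames > tile_latent_min_num_frames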
@@ -987,9 +988,10 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
 
     def temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
         B, C, T, H, W = x.shape
+        tile_latent_min_tsize = self.tile_sample_min_tsize // self.temporal_compression_ratio
         overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
-        blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
-        t_limit = self.tile_latent_min_tsize - blend_extent
+        blend_extent = int(tile_latent_min_tsize * self.tile_overlap_factor)
+        t_limit = tile_latent_min_tsize - blend_extent
 
         # Split the video into tiles and encode them separately.
         row = []
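To make the encode-side numbers concrete, here is a worked sketch of the quantities computed above; the tile_overlap_factor of 0.25 is an assumption chosen for illustration, not a value fixed by this commit:

# Hypothetical values, tracing temporal_tiled_encode's setup line by line.
tile_sample_min_tsize = 64          # set in __init__ above
temporal_compression_ratio = 4      # assumed HunyuanVideo default
tile_overlap_factor = 0.25          # assumed for illustration

tile_latent_min_tsize = tile_sample_min_tsize // temporal_compression_ratio  # 16
overlap_size = int(tile_sample_min_tsize * (1 - tile_overlap_factor))        # 48 sample frames between tile starts
blend_extent = int(tile_latent_min_tsize * tile_overlap_factor)              # 4 latent frames blended with the previous tile
t_limit = tile_latent_min_tsize - blend_extent                               # 12 latent frames kept from each tile
print(overlap_size, blend_extent, t_limit)  # 48 4 12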
@@ -1020,13 +1022,14 @@ def temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Un
         # Split z into overlapping tiles and decode them separately.
 
         B, C, T, H, W = z.shape
-        overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
+        tile_latent_min_tsize = self.tile_sample_min_tsize // self.temporal_compression_ratio
+        overlap_size = int(tile_latent_min_tsize * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
         t_limit = self.tile_sample_min_tsize - blend_extent
 
         row = []
         for i in range(0, T, overlap_size):
-            tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
+            tile = z[:, :, i : i + tile_latent_min_tsize + 1, :, :]
             if self.use_tiling and (
                 tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
             ):
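The decode side mirrors the encode side, except that the loop strides in latent frames while blending happens in sample space, and each tile grabs one extra latent frame (the "+ 1"). A small runnable sketch of the slicing, again with temporal_compression_ratio = 4 and tile_overlap_factor = 0.25 assumed purely for illustration:

import torch

# Hypothetical values, tracing temporal_tiled_decode's tiling over a dummy latent.
tile_sample_min_tsize = 64
temporal_compression_ratio = 4      # assumed HunyuanVideo default
tile_overlap_factor = 0.25          # assumed for illustration

tile_latent_min_tsize = tile_sample_min_tsize // temporal_compression_ratio  # 16
overlap_size = int(tile_latent_min_tsize * (1 - tile_overlap_factor))        # 12 latent frames between tile starts
blend_extent = int(tile_sample_min_tsize * tile_overlap_factor)              # 16 sample frames blended after decoding
t_limit = tile_sample_min_tsize - blend_extent                               # 48 sample frames kept from each tile

z = torch.randn(1, 16, 33, 32, 32)  # dummy (B, C, T, H, W) latent, not a real model output
tiles = [z[:, :, i : i + tile_latent_min_tsize + 1] for i in range(0, z.shape[2], overlap_size)]
print([t.shape[2] for t in tiles])  # [17, 17, 9] -- tiles start at latent frames 0, 12, 24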
