diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
index fbcb964392f9..046bcd00bc00 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -22,7 +22,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import logging
+from ...utils import deprecate, logging
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
 from ..downsampling import CogVideoXDownsample3D
@@ -1086,9 +1086,23 @@ def __init__(
         self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
         self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
 
+        self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
+        self.temporal_compression_ratio = temporal_compression_ratio
+
+        # When decoding a batch of video latents, one can save memory by slicing across the batch dimension and
+        # decoding one video latent at a time.
         self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the latent
+        # frames into smaller spatial tiles, decoding each tile in a separate forward pass, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
         self.use_tiling = False
+
+        # When encoding or decoding temporally long videos, the memory requirement is very high. By processing the
+        # frames in smaller batches (of size `self.num_sample_frames_batch_size` for encoding and
+        # `self.num_latent_frames_batch_size` for decoding), the memory requirement can be lowered.
+        self.use_framewise_encoding = True
+        self.use_framewise_decoding = True
+
         # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
         # recommended because the temporal parts of the VAE, here, are tricky to understand.
         # If you decode X latent frames together, the number of output frames is:
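
A quick sanity check on the two ratios introduced above (a sketch; the concrete numbers assume the released CogVideoX VAE configuration, i.e. four `block_out_channels` entries and `temporal_compression_ratio = 4`):

import math  # not strictly needed; kept to make this a standalone script

# Sketch: how the new ratios map sample-space sizes to latent-space sizes.
spatial_compression_ratio = 2 ** (4 - 1)  # 8, from len(block_out_channels) == 4
temporal_compression_ratio = 4

sample_height, sample_width, num_frames = 480, 720, 49
latent_height = sample_height // spatial_compression_ratio  # 60
latent_width = sample_width // spatial_compression_ratio    # 90
# The causal VAE leaves the first frame uncompressed in time, so:
num_latent_frames = (num_frames - 1) // temporal_compression_ratio + 1  # 13
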
@@ -1109,18 +1123,11 @@ def __init__(
         self.num_sample_frames_batch_size = 8
 
         # We make the minimum height and width of sample for tiling half that of the generally supported
-        self.tile_sample_min_height = sample_height // 2
-        self.tile_sample_min_width = sample_width // 2
-        self.tile_latent_min_height = int(
-            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
-        )
-        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
 
-        # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
-        # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
-        # and so the tiling implementation has only been tested on those specific resolutions.
-        self.tile_overlap_factor_height = 1 / 6
-        self.tile_overlap_factor_width = 1 / 5
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
@@ -1132,6 +1139,8 @@ def enable_tiling(
         tile_sample_min_width: Optional[int] = None,
         tile_overlap_factor_height: Optional[float] = None,
         tile_overlap_factor_width: Optional[float] = None,
+        tile_sample_stride_height: Optional[int] = None,
+        tile_sample_stride_width: Optional[int] = None,
     ) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1143,24 +1152,36 @@
                 The minimum height required for a sample to be separated into tiles across the height dimension.
             tile_sample_min_width (`int`, *optional*):
                 The minimum width required for a sample to be separated into tiles across the width dimension.
-            tile_overlap_factor_height (`int`, *optional*):
-                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
-                no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
-                value might cause more tiles to be processed leading to slow down of the decoding process.
-            tile_overlap_factor_width (`int`, *optional*):
-                The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
-                are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
-                value might cause more tiles to be processed leading to slow down of the decoding process.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between two consecutive vertical tiles. Keeping it smaller than the tile height leaves an
+                overlap that is blended, so that no tiling artifacts are produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. Keeping it smaller than the tile width leaves an
+                overlap that is blended, so that no tiling artifacts are produced across the width dimension.
         """
+        if tile_overlap_factor_height is not None or tile_overlap_factor_width is not None:
+            deprecate(
+                "tile_overlap_factor",
+                "1.0.0",
+                "The arguments `tile_overlap_factor_height` and `tile_overlap_factor_width` are deprecated. Please use `tile_sample_stride_height` and `tile_sample_stride_width` instead. For now, any overlap factor that is passed is converted to an equivalent stride automatically, so existing behaviour is preserved.",
+            )
+            if tile_overlap_factor_height is not None:
+                tile_sample_stride_height = (
+                    int((1 - tile_overlap_factor_height) * (tile_sample_min_height or self.tile_sample_min_height))
+                    // self.spatial_compression_ratio
+                    * self.spatial_compression_ratio
+                )
+            if tile_overlap_factor_width is not None:
+                tile_sample_stride_width = (
+                    int((1 - tile_overlap_factor_width) * (tile_sample_min_width or self.tile_sample_min_width))
+                    // self.spatial_compression_ratio
+                    * self.spatial_compression_ratio
+                )
+
         self.use_tiling = True
         self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
-        self.tile_latent_min_height = int(
-            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
-        )
-        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
-        self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
-        self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
 
     def disable_tiling(self) -> None:
         r"""
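
To see what the backward-compatibility shim above actually computes, here is the conversion worked through with values from this PR (pure arithmetic on the code above; nothing beyond it is assumed):

# Overlap factor -> stride, snapped down to a multiple of the compression ratio.
tile_sample_min_height = 96        # the value the tests below pass
tile_overlap_factor_height = 1 / 12
spatial_compression_ratio = 8

stride = int((1 - tile_overlap_factor_height) * tile_sample_min_height)   # 88
stride = stride // spatial_compression_ratio * spatial_compression_ratio  # 88 (already a multiple of 8)

# Applying the old default height factor (1/6) to the new 256-pixel minimum
# would give int(5/6 * 256) = 213 -> 208. The updated tests sidestep the shim
# entirely and pass a stride of 64 directly.
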
@@ -1189,24 +1210,23 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
         if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
             return self.tiled_encode(x)
 
-        frame_batch_size = self.num_sample_frames_batch_size
-        # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-        # As the extra single frame is handled inside the loop, it is not required to round up here.
-        num_batches = max(num_frames // frame_batch_size, 1)
-        conv_cache = None
-        enc = []
-
-        for i in range(num_batches):
-            remaining_frames = num_frames % frame_batch_size
-            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
-            end_frame = frame_batch_size * (i + 1) + remaining_frames
-            x_intermediate = x[:, :, start_frame:end_frame]
-            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
+        if self.use_framewise_encoding:
+            enc = []
+            conv_cache = None
+
+            for i in range(0, num_frames, self.num_sample_frames_batch_size):
+                x_intermediate = x[:, :, i : i + self.num_sample_frames_batch_size]
+                x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
+                if self.quant_conv is not None:
+                    x_intermediate = self.quant_conv(x_intermediate)
+                enc.append(x_intermediate)
+
+            enc = torch.cat(enc, dim=2)
+        else:
+            enc, _ = self.encoder(x)
             if self.quant_conv is not None:
-                x_intermediate = self.quant_conv(x_intermediate)
-            enc.append(x_intermediate)
+                enc = self.quant_conv(enc)
 
-        enc = torch.cat(enc, dim=2)
         return enc
 
     @apply_forward_hook
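
The rewritten loop is easier to reason about than the old `remaining_frames` bookkeeping: `range(0, num_frames, batch)` covers any frame count, and a trailing partial batch falls out of ordinary slicing. A worked example (plain arithmetic, using the batch size set in `__init__`):

# Sketch: frame batches produced by the new encoding loop for a 49-frame video.
num_frames = 49
num_sample_frames_batch_size = 8

batches = [(i, min(i + num_sample_frames_batch_size, num_frames))
           for i in range(0, num_frames, num_sample_frames_batch_size)]
# [(0, 8), (8, 16), (16, 24), (24, 32), (32, 40), (40, 48), (48, 49)]
# The conv_cache returned by each encoder call is fed into the next call, so the
# causal convolutions see one continuous stream and the result matches a single pass.
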
@@ -1239,26 +1259,28 @@ def encode(
     def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         batch_size, num_channels, num_frames, height, width = z.shape
 
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
 
-        if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, return_dict=return_dict)
 
-        frame_batch_size = self.num_latent_frames_batch_size
-        num_batches = max(num_frames // frame_batch_size, 1)
-        conv_cache = None
-        dec = []
+        if self.use_framewise_decoding:
+            dec = []
+            conv_cache = None
 
-        for i in range(num_batches):
-            remaining_frames = num_frames % frame_batch_size
-            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
-            end_frame = frame_batch_size * (i + 1) + remaining_frames
-            z_intermediate = z[:, :, start_frame:end_frame]
-            if self.post_quant_conv is not None:
-                z_intermediate = self.post_quant_conv(z_intermediate)
-            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
-            dec.append(z_intermediate)
+            for i in range(0, num_frames, self.num_latent_frames_batch_size):
+                z_intermediate = z[:, :, i : i + self.num_latent_frames_batch_size]
+                if self.post_quant_conv is not None:
+                    z_intermediate = self.post_quant_conv(z_intermediate)
+                z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
+                dec.append(z_intermediate)
 
-        dec = torch.cat(dec, dim=2)
+            dec = torch.cat(dec, dim=2)
+        else:
+            if self.post_quant_conv is not None:
+                z = self.post_quant_conv(z)
+            dec, _ = self.decoder(z)
 
         if not return_dict:
             return (dec,)
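
For intuition on when `_decode` switches to the tiled path and how the tiles land, here are the numbers with the new defaults (256-pixel minimum tile, 192-pixel stride, spatial compression 8); this is just arithmetic over the code above:

# Sketch: tile geometry in latent space with the new defaults.
tile_latent_min_height = 256 // 8     # 32
tile_latent_stride_height = 192 // 8  # 24

latent_height = 60  # e.g. a 480-pixel-tall video
# 60 > 32, so tiled_decode is used. Tiles start at:
tile_starts = list(range(0, latent_height, tile_latent_stride_height))  # [0, 24, 48]
# Each tile spans 32 latent rows, so consecutive tiles overlap by 32 - 24 = 8 rows;
# that overlap is exactly the `blend_height` that tiled_decode smooths over.
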
@@ -1324,44 +1346,48 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         # For a rough memory estimate, take a look at the `tiled_decode` method.
         batch_size, num_channels, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
 
-        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
-        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
-        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
-        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
-        row_limit_height = self.tile_latent_min_height - blend_extent_height
-        row_limit_width = self.tile_latent_min_width - blend_extent_width
-        frame_batch_size = self.num_sample_frames_batch_size
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
 
         # Split x into overlapping tiles and encode them separately.
         # The tiles have an overlap to avoid seams between tiles.
         rows = []
-        for i in range(0, height, overlap_height):
+        for i in range(0, height, self.tile_sample_stride_height):
             row = []
-            for j in range(0, width, overlap_width):
-                # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-                # As the extra single frame is handled inside the loop, it is not required to round up here.
-                num_batches = max(num_frames // frame_batch_size, 1)
-                conv_cache = None
-                time = []
-
-                for k in range(num_batches):
-                    remaining_frames = num_frames % frame_batch_size
-                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
-                    end_frame = frame_batch_size * (k + 1) + remaining_frames
-                    tile = x[
-                        :,
-                        :,
-                        start_frame:end_frame,
-                        i : i + self.tile_sample_min_height,
-                        j : j + self.tile_sample_min_width,
-                    ]
-                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
+            for j in range(0, width, self.tile_sample_stride_width):
+                if self.use_framewise_encoding:
+                    time = []
+                    conv_cache = None
+
+                    for k in range(0, num_frames, self.num_sample_frames_batch_size):
+                        tile = x[
+                            :,
+                            :,
+                            k : k + self.num_sample_frames_batch_size,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                        tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
+                        if self.quant_conv is not None:
+                            tile = self.quant_conv(tile)
+                        time.append(tile)
+
+                    time = torch.cat(time, dim=2)
+                else:
+                    tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    time, _ = self.encoder(tile)
                     if self.quant_conv is not None:
-                        tile = self.quant_conv(tile)
-                    time.append(tile)
+                        time = self.quant_conv(time)
 
-                row.append(torch.cat(time, dim=2))
+                row.append(time)
             rows.append(row)
 
         result_rows = []
@@ -1371,13 +1397,13 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
                 # blend the above tile and the left tile
                 # to the current tile and add the current tile to the result row
                 if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                 if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
-                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
             result_rows.append(torch.cat(result_row, dim=4))
 
-        enc = torch.cat(result_rows, dim=3)
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
         return enc
 
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
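
`blend_height` and `blend_width` are handed to the pre-existing `blend_v`/`blend_h` helpers, which this diff does not touch. For reference, a paraphrased sketch of what the vertical blend does (based on the helper as it exists in this file; treat it as illustrative rather than authoritative):

import torch

# Paraphrased sketch of blend_v; tensors are (batch, channels, frames, height, width).
def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for y in range(blend_extent):
        w = y / blend_extent  # 0 at the top of the seam, approaching 1 at the bottom
        b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - w) + b[:, :, :, y, :] * w
    return b
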
@@ -1405,58 +1431,63 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
         batch_size, num_channels, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
 
-        overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
-        overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
-        blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
-        blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
-        row_limit_height = self.tile_sample_min_height - blend_extent_height
-        row_limit_width = self.tile_sample_min_width - blend_extent_width
-        frame_batch_size = self.num_latent_frames_batch_size
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
 
         # Split z into overlapping tiles and decode them separately.
         # The tiles have an overlap to avoid seams between tiles.
         rows = []
-        for i in range(0, height, overlap_height):
+        for i in range(0, height, tile_latent_stride_height):
             row = []
-            for j in range(0, width, overlap_width):
-                num_batches = max(num_frames // frame_batch_size, 1)
-                conv_cache = None
-                time = []
-
-                for k in range(num_batches):
-                    remaining_frames = num_frames % frame_batch_size
-                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
-                    end_frame = frame_batch_size * (k + 1) + remaining_frames
-                    tile = z[
-                        :,
-                        :,
-                        start_frame:end_frame,
-                        i : i + self.tile_latent_min_height,
-                        j : j + self.tile_latent_min_width,
-                    ]
+            for j in range(0, width, tile_latent_stride_width):
+                if self.use_framewise_decoding:
+                    time = []
+                    conv_cache = None
+
+                    for k in range(0, num_frames, self.num_latent_frames_batch_size):
+                        tile = z[
+                            :,
+                            :,
+                            k : k + self.num_latent_frames_batch_size,
+                            i : i + tile_latent_min_height,
+                            j : j + tile_latent_min_width,
+                        ]
+                        if self.post_quant_conv is not None:
+                            tile = self.post_quant_conv(tile)
+                        tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
+                        time.append(tile)
+
+                    time = torch.cat(time, dim=2)
+                else:
+                    tile = z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
                     if self.post_quant_conv is not None:
                         tile = self.post_quant_conv(tile)
-                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
-                    time.append(tile)
+                    time, _ = self.decoder(tile)
 
-                row.append(torch.cat(time, dim=2))
+                row.append(time)
             rows.append(row)
 
         result_rows = []
         for i, row in enumerate(rows):
             result_row = []
             for j, tile in enumerate(row):
-                # blend the above tile and the left tile
-                # to the current tile and add the current tile to the result row
+                # Blend the above tile and the left tile to the current tile and add the current tile to the result row
                 if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                 if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
-                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
             result_rows.append(torch.cat(result_row, dim=4))
 
-        dec = torch.cat(result_rows, dim=3)
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
 
         if not return_dict:
             return (dec,)
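
Taken together, the new controls compose like this in user code (a usage sketch: the method and attribute names come from this diff and the surrounding class, while the checkpoint id, dtype, and `latents` tensor are illustrative assumptions):

import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

vae.enable_slicing()    # decode one video latent of the batch at a time
vae.enable_tiling(      # 256-pixel tiles placed every 192 pixels
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_stride_height=192,
    tile_sample_stride_width=192,
)
# use_framewise_decoding defaults to True, so temporally long latents are decoded
# in batches of num_latent_frames_batch_size frames that share a conv cache.
frames = vae.decode(latents).sample  # `latents` assumed to be (B, C, T, H, W)
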
diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py
index 884ddfb2a95a..d80206eb62f8 100644
--- a/tests/pipelines/cogvideo/test_cogvideox.py
+++ b/tests/pipelines/cogvideo/test_cogvideox.py
@@ -268,8 +268,8 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
         pipe.vae.enable_tiling(
             tile_sample_min_height=96,
             tile_sample_min_width=96,
-            tile_overlap_factor_height=1 / 12,
-            tile_overlap_factor_width=1 / 12,
+            tile_sample_stride_height=64,
+            tile_sample_stride_width=64,
         )
         inputs = self.get_dummy_inputs(generator_device)
         inputs["height"] = inputs["width"] = 128
diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
index 2a51fc65798c..efede9dee61c 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
@@ -272,8 +272,8 @@ def test_vae_tiling(self, expected_diff_max: float = 0.5):
         pipe.vae.enable_tiling(
             tile_sample_min_height=96,
             tile_sample_min_width=96,
-            tile_overlap_factor_height=1 / 12,
-            tile_overlap_factor_width=1 / 12,
+            tile_sample_stride_height=64,
+            tile_sample_stride_width=64,
         )
         inputs = self.get_dummy_inputs(generator_device)
         inputs["height"] = inputs["width"] = 128
diff --git a/tests/pipelines/cogvideo/test_cogvideox_image2video.py b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
index f7e1fe7fd6c7..621ec9312714 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_image2video.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
@@ -291,8 +291,8 @@ def test_vae_tiling(self, expected_diff_max: float = 0.3):
         pipe.vae.enable_tiling(
             tile_sample_min_height=96,
             tile_sample_min_width=96,
-            tile_overlap_factor_height=1 / 12,
-            tile_overlap_factor_width=1 / 12,
+            tile_sample_stride_height=64,
+            tile_sample_stride_width=64,
         )
         inputs = self.get_dummy_inputs(generator_device)
         inputs["height"] = inputs["width"] = 128
diff --git a/tests/pipelines/cogvideo/test_cogvideox_video2video.py b/tests/pipelines/cogvideo/test_cogvideox_video2video.py
index 4d836cb5e2a4..173079dc2a80 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_video2video.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_video2video.py
@@ -273,8 +273,8 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
         pipe.vae.enable_tiling(
             tile_sample_min_height=96,
             tile_sample_min_width=96,
-            tile_overlap_factor_height=1 / 12,
-            tile_overlap_factor_width=1 / 12,
+            tile_sample_stride_height=64,
+            tile_sample_stride_width=64,
        )
         inputs = self.get_dummy_inputs(generator_device)
         inputs["height"] = inputs["width"] = 128
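
The test values line up with the new geometry: the dummy inputs are 128x128, which exceeds the 96-pixel tile minimum, so the tiled code path is actually exercised. A quick check (assuming spatial compression 8, as in the main model):

# Sample space: 96-pixel tiles placed every 64 pixels across a 128-pixel frame.
tile_starts = list(range(0, 128, 64))  # [0, 64]; the two tiles overlap by 96 - 64 = 32 px
# Latent space (compression 8): frame 16, tile minimum 12, tile stride 8.
latent_tile_starts = list(range(0, 128 // 8, 64 // 8))  # [0, 8]; overlap 12 - 8 = 4 rows
# Blending that overlap is what helps keep the tiled/non-tiled difference within
# each test's expected_diff_max.
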