diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
index 7834206ddb4a..55a3764f129f 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -117,6 +117,8 @@ def __init__(
             dilation=dilation,
         )
 
+        self.return_conv_cache = True
+
     def fake_context_parallel_forward(
         self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
@@ -128,7 +130,10 @@ def fake_context_parallel_forward(
 
     def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
         inputs = self.fake_context_parallel_forward(inputs, conv_cache)
-        conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+        if self.return_conv_cache:
+            conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+        else:
+            conv_cache = None
 
         padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
         inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
@@ -1079,6 +1084,7 @@ def __init__(
 
         self.use_slicing = False
         self.use_tiling = False
+        self.use_framewise_batching = True
 
         # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
         # recommended because the temporal parts of the VAE, here, are tricky to understand.
@@ -1174,6 +1180,20 @@ def disable_slicing(self) -> None:
         """
         self.use_slicing = False
 
+    def enable_framewise_batching(self) -> None:
+        self.use_framewise_batching = True
+
+        for name, module in self.named_modules():
+            if isinstance(module, CogVideoXCausalConv3d):
+                module.return_conv_cache = True
+
+    def disable_framewise_batching(self) -> None:
+        self.use_framewise_batching = False
+
+        for name, module in self.named_modules():
+            if isinstance(module, CogVideoXCausalConv3d):
+                module.return_conv_cache = False
+
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = x.shape
 
@@ -1184,19 +1204,26 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
         # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
         num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
         conv_cache = None
-        enc = []
-
-        for i in range(num_batches):
-            remaining_frames = num_frames % frame_batch_size
-            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
-            end_frame = frame_batch_size * (i + 1) + remaining_frames
-            x_intermediate = x[:, :, start_frame:end_frame]
-            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
+
+        if self.use_framewise_batching:
+            enc = []
+
+            for i in range(num_batches):
+                remaining_frames = num_frames % frame_batch_size
+                start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+                end_frame = frame_batch_size * (i + 1) + remaining_frames
+                x_intermediate = x[:, :, start_frame:end_frame]
+                x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
+                if self.quant_conv is not None:
+                    x_intermediate = self.quant_conv(x_intermediate)
+                enc.append(x_intermediate)
+
+            enc = torch.cat(enc, dim=2)
+        else:
+            enc, _ = self.encoder(x, conv_cache=conv_cache)
             if self.quant_conv is not None:
-                x_intermediate = self.quant_conv(x_intermediate)
-            enc.append(x_intermediate)
+                enc = self.quant_conv(enc)
 
-        enc = torch.cat(enc, dim=2)
         return enc
 
     @apply_forward_hook
@@ -1236,19 +1263,25 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         frame_batch_size = self.num_latent_frames_batch_size
         num_batches = max(num_frames // frame_batch_size, 1)
         conv_cache = None
-        dec = []
-
-        for i in range(num_batches):
-            remaining_frames = num_frames % frame_batch_size
-            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
-            end_frame = frame_batch_size * (i + 1) + remaining_frames
-            z_intermediate = z[:, :, start_frame:end_frame]
-            if self.post_quant_conv is not None:
-                z_intermediate = self.post_quant_conv(z_intermediate)
-            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
-            dec.append(z_intermediate)
+        if self.use_framewise_batching:
+            dec = []
 
-        dec = torch.cat(dec, dim=2)
+            for i in range(num_batches):
+                remaining_frames = num_frames % frame_batch_size
+                start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+                end_frame = frame_batch_size * (i + 1) + remaining_frames
+                z_intermediate = z[:, :, start_frame:end_frame]
+                if self.post_quant_conv is not None:
+                    z_intermediate = self.post_quant_conv(z_intermediate)
+                z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
+                dec.append(z_intermediate)
+
+            dec = torch.cat(dec, dim=2)
+        else:
+            if self.post_quant_conv is not None:
+                z = self.post_quant_conv(z)
+            dec, _ = self.decoder(z, conv_cache=conv_cache)
 
         if not return_dict:
             return (dec,)
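
A minimal usage sketch of the new toggles, assuming the patch above is applied to AutoencoderKLCogVideoX; the checkpoint id and tensor shape below are illustrative and not part of the patch:

    import torch
    from diffusers import AutoencoderKLCogVideoX

    # Illustrative checkpoint; any CogVideoX VAE checkpoint exposes the same toggles once patched.
    vae = AutoencoderKLCogVideoX.from_pretrained(
        "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
    ).to("cuda")

    # Framewise batching is enabled by default: frames are processed in chunks of
    # num_sample_frames_batch_size / num_latent_frames_batch_size, carrying the conv cache between chunks.
    video = torch.randn(1, 3, 49, 480, 720, dtype=torch.float16, device="cuda")  # (B, C, F, H, W)
    latents = vae.encode(video).latent_dist.sample()

    # Disabling it runs the encoder/decoder over all frames in a single pass and skips building the
    # conv cache in the causal convs, trading higher peak memory for one forward call.
    vae.disable_framewise_batching()
    frames = vae.decode(latents).sample
    vae.enable_framewise_batching()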