Implement framewise encoding/decoding in LTX Video VAE #10488
Merged
Changes from 1 commit

Commits (8):
- ec918b9 (rootonchair): add framewise decode
- 64a0849 (rootonchair): add framewise encode, refactor tiled encode/decode
- e79162c (rootonchair): add sanity test tiling for ltx
- b7d02c9 (rootonchair): Merge branch 'main' into ltx_vae_framewise
- c5e6d62 (rootonchair): run make style
- ab31c3b (rootonchair): Merge branch 'ltx_vae_framewise' of github.com:rootonchair/diffusers …
- 220fc78 (rootonchair): Merge branch 'main' into ltx_vae_framewise
- 88bfc36 (rootonchair): Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

```diff
@@ -1010,10 +1010,12 @@ def __init__(
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 512
         self.tile_sample_min_width = 512
+        self.tile_sample_min_num_frames = 16

         # The minimal distance between two spatial tiles
         self.tile_sample_stride_height = 448
         self.tile_sample_stride_width = 448
+        self.tile_sample_stride_num_frames = 8

     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
```
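For intuition on these new defaults: `_temporal_tiled_decode` (added below) derives the latent-space tile sizes from these sample-space values by integer division with the VAE's compression ratios. A minimal sketch of that arithmetic, assuming a temporal compression ratio of 8 for LTX (the ratio value is an assumption here, not stated in this diff; in practice it would be read from the model):

```python
# Illustrative only: mirrors the derivations inside _temporal_tiled_decode.
temporal_compression_ratio = 8  # assumed LTX default

tile_sample_min_num_frames = 16    # frames per decoded tile (set in __init__ above)
tile_sample_stride_num_frames = 8  # frames between consecutive tile starts

tile_latent_min_num_frames = tile_sample_min_num_frames // temporal_compression_ratio        # 2
tile_latent_stride_num_frames = tile_sample_stride_num_frames // temporal_compression_ratio  # 1

# Consecutive decoded tiles overlap by this many sample frames; blend_t
# (added below) crossfades exactly this overlap.
blend_num_frames = tile_sample_min_num_frames - tile_sample_stride_num_frames  # 8
print(tile_latent_min_num_frames, tile_latent_stride_num_frames, blend_num_frames)  # 2 1 8
```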
```diff
@@ -1114,6 +1116,53 @@ def encode(
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)

+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b
+
+    def _temporal_tiled_decode(self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = z.shape
+        num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, tile_latent_stride_num_frames):
+            tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
+            if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
+                decoded = self.tiled_decode(tile, temb, return_dict=True).sample
+            else:
+                print("NOT Use tile decode")
+                print(f"input tile: {tile.size()}")
+                decoded = self.decoder(tile, temb)
+                print(f"output tile: {decoded.size()}")
+            if i > 0:
+                decoded = decoded[:, :, :-1, :, :]
+            row.append(decoded)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
+                result_row.append(tile)
+            else:
+                result_row.append(tile[:, :, :self.tile_sample_stride_num_frames + 1, :, :])
+
+        dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def _decode(
         self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
```
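To see what `blend_t` computes in isolation, here is a self-contained sketch of the same linear crossfade along the frame axis (dim 2), written as a free function rather than a method; the toy tensor shapes are illustrative:

```python
import torch

def blend_t(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Crossfade the last `blend_extent` frames of `a` into the first
    # `blend_extent` frames of `b`, ramping the weight on `b` from 0 toward 1.
    blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
    for x in range(blend_extent):
        b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
            x / blend_extent
        )
    return b

# Two overlapping tiles of shape (batch, channels, frames, height, width).
a = torch.zeros(1, 1, 4, 2, 2)  # previous tile
b = torch.ones(1, 1, 4, 2, 2)   # current tile
out = blend_t(a, b.clone(), blend_extent=4)
print(out[0, 0, :, 0, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500])
```

Note that `blend_t` mutates `b` in place, which is why the sketch passes `b.clone()`.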
```diff
@@ -1125,13 +1174,8 @@ def _decode(
         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
            return self.tiled_decode(z, temb, return_dict=return_dict)

-        if self.use_framewise_decoding:
-            # TODO(aryan): requires investigation
-            raise NotImplementedError(
-                "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                "quality issues caused by splitting inference across frame dimension. If you believe this "
-                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-            )
-        dec = self.decoder(z, temb)
+        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
+            dec = self._temporal_tiled_decode(z, temb, return_dict=False)[0]
+        else:
+            dec = self.decoder(z, temb)
```
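Taken together, `_decode` now routes to `_temporal_tiled_decode` whenever `use_framewise_decoding` is set and the latent clip is longer than one temporal tile. A hypothetical end-to-end sketch of exercising the new path (the checkpoint location, latent shape, and direct flag toggle are illustrative assumptions, not prescribed by this PR):

```python
import torch
from diffusers import AutoencoderKLLTXVideo

# Assumed checkpoint location; any diffusers-format LTX Video VAE should work.
vae = AutoencoderKLLTXVideo.from_pretrained(
    "Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32
)
vae.use_framewise_decoding = True  # opt in to the new temporal-tiled path

# Latents are (batch, channels, frames, height, width); 5 latent frames
# exceeds tile_latent_min_num_frames, so decoding is split across tiles.
latents = torch.randn(1, vae.config.latent_channels, 5, 16, 16)
with torch.no_grad():
    video = vae.decode(latents, return_dict=False)[0]
print(video.shape)  # frames expand to (5 - 1) * temporal_compression_ratio + 1
```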
Review comment: Would move this down a few methods, to where `blend_h` and `blend_v` are located.