Add VidTok AutoEncoders #11261
Conversation
Thank you for the PR @annitang1997! I will review this in depth soon. cc @yiyixuxu too

Are there any updates on the review process? 👀 Looking forward to using VidTok with diffusers.
Thank you for the PR and congratulations on the release of your awesome work!
I did a first review pass on the changes needed to make the implementation consistent with the rest of the diffusers codebase. There are some core implementation details that will have to be refactored before we can merge. Good reference implementations for autoencoders can be found here:
- https://github.com/huggingface/diffusers/blob/0dec414d5bf2c7fe77684722b0a97324798bd7b3/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
- https://github.com/huggingface/diffusers/blob/0dec414d5bf2c7fe77684722b0a97324798bd7b3/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
I'd be happy to help make some of these changes! 🤗
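For illustration only, the linked autoencoders roughly share a common skeleton; a minimal sketch of that pattern (the config arguments and method bodies here are placeholders, not the final VidTok API):

```python
# Minimal sketch of the usual diffusers autoencoder layout (illustrative only;
# the actual VidTok signatures are defined by this PR, not here).
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class AutoencoderVidTok(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(self, in_channels: int = 3, latent_channels: int = 4):
        super().__init__()
        # encoder/decoder submodules (and any regularizer/quantizer) are built here,
        # all living in the same autoencoder modeling file

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # single entry point for encoding; tiling/slicing hooks in here
        raise NotImplementedError

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        # single entry point for decoding
        raise NotImplementedError

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(sample))
```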
`return z_q`
`class FSQRegularizer(nn.Module):`
We're moving towards maintaining a single file per modeling implementation, and so let's move this to the vidtok autoencoder file
In `src/diffusers/models/downsampling.py` (outdated diff):
`return F.conv2d(inputs, weight, stride=2)`
`class VidTokDownsample2D(nn.Module):`
Let's move this to the vidtok autoencoder file as well.
`return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]`
`class VidTokLayerNorm(nn.Module):`
Let's move this to the vidtok autoencoder file as well.
In `src/diffusers/models/upsampling.py` (outdated diff):
`return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)`
`class VidTokUpsample2D(nn.Module):`
Let's move this to the vidtok autoencoder file as well.
`import torch`
`import torch.nn as nn`
`import torch.nn.functional as F`
`from einops import pack, rearrange, unpack`
Need to replace all einops operations with permute/reshape/other ops, since einops adds another dependency which we don't use in the codebase.
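For example, the `rearrange` pattern used in this file maps onto plain `permute`/`reshape` like this (a standalone sketch with dummy shapes, not lines from this PR):

```python
import torch

x = torch.randn(2, 8, 4, 16, 16)  # (b, c, t, h, w) with dummy sizes

# einops: rearrange(x, "b c t h w -> (b t) c h w")
b, c, t, h, w = x.shape
y = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)

# inverse, einops: rearrange(y, "(b t) c h w -> b c t h w", t=t)
x_back = y.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4)
assert torch.equal(x_back, x)
```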
`def create_custom_forward(module):`
`    def custom_forward(*inputs):`
`        return module.downsample(*inputs)`
`    return custom_forward`
Suggested change: remove the `create_custom_forward` helper:
`def create_custom_forward(module):`
`    def custom_forward(*inputs):`
`        return module.downsample(*inputs)`
`    return custom_forward`
`if i_level in self.spatial_ds:`
`    # spatial downsample`
`    htmp = rearrange(hs[-1], "b c t h w -> (b t) c h w")`
`    htmp = torch.utils.checkpoint.checkpoint(create_custom_forward(self.down[i_level]), htmp)`
Suggested change: replace
`htmp = torch.utils.checkpoint.checkpoint(create_custom_forward(self.down[i_level]), htmp)`
with
`htmp = self._gradient_checkpointing_func(self.down[i_level], htmp)`
`B, _, T, H, W = htmp.shape`
`# middle`
`h = hs[-1]`
`h = torch.utils.checkpoint.checkpoint(self.mid.block_1, h, temb)`
same comment as above for these usages
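(For illustration, applying the same replacement here would look roughly like the following; the gradient-checkpointing guard is the usual diffusers pattern and is an assumption about how the final code will be structured.)

```python
# sketch only: mirror the suggested change above for the mid-block call
if torch.is_grad_enabled() and self.gradient_checkpointing:
    h = self._gradient_checkpointing_func(self.mid.block_1, h, temb)
else:
    h = self.mid.block_1(h, temb)
```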
`return h`
`class AutoencoderVidTok(ModelMixin, ConfigMixin, FromOriginalModelMixin):`
Suggested change: replace
`class AutoencoderVidTok(ModelMixin, ConfigMixin, FromOriginalModelMixin):`
with
`class AutoencoderVidTok(ModelMixin, ConfigMixin):`
`self.tile_overlap_factor_width = 0.0  # 1 / 8`
`@staticmethod`
`def pad_at_dim(`
Any methods that are not meant to be directly invoked by users should be made private (that is, prefixed with an underscore: `_pad_at_dim`).
Hello, I have improved the code based on your feedback. Please check it. 🤗

Any updates in this thread? :)

@deeptimhe Sorry for the delay, I'm on leave at the moment, and so is @yiyixuxu. I'll try to test the PR and give it a look next week when I'm back.
Thanks for the PR!
I left some feedback. One note on diffusers coding style: we try not to use too many small methods/functions; ideally all the logic is implemented in `forward`.
I made a few examples in the review. If you can apply similar changes throughout the implementation, it would be great :)
`codes = codes.permute(0, -1, *range(1, codes.dim() - 1))`
`return codes`
`@torch.cuda.amp.autocast(enabled=False)`
can you remove the autocast?
`self.global_codebook_usage = torch.zeros([2**self.codebook_dim, self.num_codebooks], dtype=torch.long)`
`@staticmethod`
`def default(*args) -> Any:`
can we remove this method and add default directly in the signature?
`self.num_codebooks = num_codebooks`
`self.effective_codebook_dim = effective_codebook_dim`
`keep_num_codebooks_dim = self.default(keep_num_codebooks_dim, num_codebooks > 1)`
Suggested change: replace
`keep_num_codebooks_dim = self.default(keep_num_codebooks_dim, num_codebooks > 1)`
with
`if keep_num_codebooks_dim is None:`
`    keep_num_codebooks_dim = num_codebooks > 1`
`self.effective_codebook_dim = effective_codebook_dim`
`keep_num_codebooks_dim = self.default(keep_num_codebooks_dim, num_codebooks > 1)`
`assert not (num_codebooks > 1 and not keep_num_codebooks_dim)`
Suggested change: remove
`assert not (num_codebooks > 1 and not keep_num_codebooks_dim)`
`assert not (num_codebooks > 1 and not keep_num_codebooks_dim)`
`self.keep_num_codebooks_dim = keep_num_codebooks_dim`
`self.dim = self.default(dim, len(_levels) * num_codebooks)`
Suggested change: replace
`self.dim = self.default(dim, len(_levels) * num_codebooks)`
with
`self.dim = len(_levels) * num_codebooks if dim is None else dim`
`half_width = self._levels // 2`
`return quantized / half_width`
`def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:`
remove this method and move the code into `codes_to_indices`
`half_width = self._levels // 2`
`return (zhat_normalized * half_width) + half_width`
`def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:`
same for this method
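(A sketch of what the inlining for `_scale_and_shift` could look like; the rest of `codes_to_indices` is elided and unchanged.)

```python
# sketch: fold _scale_and_shift into its caller instead of keeping it as a
# method (the same idea applies to _scale_and_shift_inverse and its caller)
def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
    half_width = self._levels // 2
    zhat = (zhat * half_width) + half_width  # was: zhat = self._scale_and_shift(zhat)
    ...  # rest of the method unchanged
```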
`is_video = False`
`z = z.reshape(b, d, -1).permute(0, 2, 1)`
`assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"`
Suggested change: remove
`assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"`
`self.cache_offset = 0`
`@staticmethod`
`def _cast_tuple(t: Union[Tuple[int], int], length: int = 1) -> Tuple[int]:`
remove this method
`super().__init__()`
`self.pad_mode = pad_mode`
`kernel_size = self._cast_tuple(kernel_size, 3)`
Suggested change: replace
`kernel_size = self._cast_tuple(kernel_size, 3)`
with
`if isinstance(kernel_size, int):`
`    kernel_size = (kernel_size,) * 3`
Hello, I have cleaned up the code by removing small methods/functions based on your feedback. Please check it. 🤗

Any updates in this thread? :)
We add VidTok, a versatile and state-of-the-art video tokenizer, as an autoencoder model to diffusers.
Paper: https://arxiv.org/pdf/2412.13061
Code: https://github.com/microsoft/VidTok
Model: https://huggingface.co/microsoft/VidTok
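Once merged, usage would presumably follow the standard diffusers autoencoder pattern. A hypothetical sketch (the import path, the checkpoint id/layout passed to `from_pretrained`, and the exact `encode`/`decode` return types are assumptions, not confirmed by this PR):

```python
import torch
from diffusers import AutoencoderVidTok  # hypothetical import path once this PR is merged

# checkpoint id/layout is illustrative only; converted weights may live elsewhere
vae = AutoencoderVidTok.from_pretrained("microsoft/VidTok", torch_dtype=torch.float32)
vae.eval()

video = torch.randn(1, 3, 17, 256, 256)  # (batch, channels, frames, height, width)
with torch.no_grad():
    enc = vae.encode(video)   # continuous (KL) or discrete (FSQ) latents, depending on config
    rec = vae.decode(enc)     # reconstructed video
```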