@annitang1997

We add VidTok, a versatile and state-of-the-art video tokenizer, as an autoencoder model to diffusers.

Paper: https://arxiv.org/pdf/2412.13061
Code: https://github.com/microsoft/VidTok
Model: https://huggingface.co/microsoft/VidTok

@a-r-r-o-w
Contributor

Thank you for the PR @annitang1997! I will review this in depth soon. cc @yiyixuxu too

@deeptimhe

deeptimhe commented Apr 20, 2025

Are there any updates on the review process? 👀 Looking forward to using VidTok with diffusers.

@a-r-r-o-w left a comment
Contributor

Thank you for the PR and congratulations on the release of your awesome work!

I did a first-pass review of the changes needed to bring the implementation in line with the rest of the diffusers codebase. Some core implementation details will have to be refactored before we can merge. A good reference implementation for autoencoders can be found here:

I'd be happy to help with some of these changes! 🤗

return z_q


class FSQRegularizer(nn.Module):
Contributor

We're moving towards a single file per modeling implementation, so let's move this to the VidTok autoencoder file.

return F.conv2d(inputs, weight, stride=2)


class VidTokDownsample2D(nn.Module):
Contributor

Let's move this to vidtok autoencoder file as well

return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]


class VidTokLayerNorm(nn.Module):
Contributor

Let's move this to vidtok autoencoder file as well

return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)


class VidTokUpsample2D(nn.Module):
Contributor

Let's move this to vidtok autoencoder file as well

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, unpack
Contributor

We need to replace all einops operations with permute/reshape/other native ops, since einops adds another dependency that we don't use in the codebase.
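As a rough illustration of that replacement, here is a minimal sketch of how the "b c t h w -> (b t) c h w" pattern used in this PR could be written with plain tensor ops. The helper names `fold_time_into_batch` and `unfold_time_from_batch` are hypothetical and not part of the PR or of diffusers.

```python
import torch


def fold_time_into_batch(x: torch.Tensor) -> torch.Tensor:
    """Equivalent of rearrange(x, "b c t h w -> (b t) c h w") without einops."""
    b, c, t, h, w = x.shape
    # Move the time axis next to the batch axis, then merge the two.
    return x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)


def unfold_time_from_batch(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Equivalent of rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)."""
    bt, c, h, w = x.shape
    # Split the merged axis back into (batch, time), then restore channel-first order.
    return x.reshape(batch_size, bt // batch_size, c, h, w).permute(0, 2, 1, 3, 4)
```

`reshape` is used rather than `view` because the permuted tensor is no longer contiguous.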

Comment on lines 604 to 610

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module.downsample(*inputs)

    return custom_forward

Contributor

Suggested change
- def create_custom_forward(module):
-     def custom_forward(*inputs):
-         return module.downsample(*inputs)
-     return custom_forward

if i_level in self.spatial_ds:
    # spatial downsample
    htmp = rearrange(hs[-1], "b c t h w -> (b t) c h w")
    htmp = torch.utils.checkpoint.checkpoint(create_custom_forward(self.down[i_level]), htmp)
Contributor

Suggested change
- htmp = torch.utils.checkpoint.checkpoint(create_custom_forward(self.down[i_level]), htmp)
+ htmp = self._gradient_checkpointing_func(self.down[i_level], htmp)

B, _, T, H, W = htmp.shape
# middle
h = hs[-1]
h = torch.utils.checkpoint.checkpoint(self.mid.block_1, h, temb)
Contributor

same comment as above for these usages
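To show why the `create_custom_forward` wrapper is unnecessary, here is a minimal sketch that passes a submodule straight to a checkpointing function. It uses plain `torch.utils.checkpoint.checkpoint` as a stand-in for the `_gradient_checkpointing_func` helper named in the suggestion; the `DownBlockSketch` class, its layer sizes, and its `gradient_checkpointing` flag are assumptions for illustration, not the PR's code.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class DownBlockSketch(nn.Module):
    """Hypothetical block: checkpoints its downsample layer without a closure."""

    def __init__(self, channels: int):
        super().__init__()
        self.downsample = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
        self.gradient_checkpointing = True  # assumed flag, toggled by the caller

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # An nn.Module is already callable, so it can be passed directly.
            return checkpoint(self.downsample, x, use_reentrant=False)
        return self.downsample(x)
```

The checkpointed path recomputes the layer's activations during the backward pass instead of storing them, so outputs and gradients should match the non-checkpointed path.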

return h


class AutoencoderVidTok(ModelMixin, ConfigMixin, FromOriginalModelMixin):
Contributor

Suggested change
- class AutoencoderVidTok(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ class AutoencoderVidTok(ModelMixin, ConfigMixin):

self.tile_overlap_factor_width = 0.0 # 1 / 8

@staticmethod
def pad_at_dim(
Contributor

Any methods that are not meant to be invoked directly by users should be made private (that is, prefixed with an underscore, e.g. _pad_at_dim).

@annitang1997
Author

annitang1997 commented May 9, 2025

Hello, I have improved the code based on your feedback. Please check it. 🤗

@deeptimhe

Any updates in this thread? :)

@a-r-r-o-w
Contributor

@deeptimhe Sorry for the delay, I'm on leave at the moment, and so is @yiyixuxu. I'll try to test the PR and give it a look next week when I'm back

@yiyixuxu left a comment
Collaborator

Thanks for the PR!
I left some feedback. One note on diffusers coding style: we try not to use too many small methods/functions. Ideally, all the logic is implemented in forward.

I made a few examples in the review; if you can apply similar changes throughout the implementation, that would be great :)

codes = codes.permute(0, -1, *range(1, codes.dim() - 1))
return codes

@torch.cuda.amp.autocast(enabled=False)
Collaborator

can you remove the autocast?

self.global_codebook_usage = torch.zeros([2**self.codebook_dim, self.num_codebooks], dtype=torch.long)

@staticmethod
def default(*args) -> Any:
Collaborator

can we remove this method and add default directly in the signature?

self.num_codebooks = num_codebooks
self.effective_codebook_dim = effective_codebook_dim

keep_num_codebooks_dim = self.default(keep_num_codebooks_dim, num_codebooks > 1)
Collaborator

Suggested change
- keep_num_codebooks_dim = self.default(keep_num_codebooks_dim, num_codebooks > 1)
+ if keep_num_codebooks_dim is None:
+     keep_num_codebooks_dim = num_codebooks > 1

self.effective_codebook_dim = effective_codebook_dim

keep_num_codebooks_dim = self.default(keep_num_codebooks_dim, num_codebooks > 1)
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
Collaborator

Suggested change
- assert not (num_codebooks > 1 and not keep_num_codebooks_dim)

assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
self.keep_num_codebooks_dim = keep_num_codebooks_dim

self.dim = self.default(dim, len(_levels) * num_codebooks)
Collaborator

Suggested change
- self.dim = self.default(dim, len(_levels) * num_codebooks)
+ self.dim = len(_levels) * num_codebooks if dim is None else dim
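Putting the suggestions about `default` together, a minimal sketch of a constructor that resolves its optional arguments inline could look like the following. The class name `FSQArgsSketch` and the exact signature are hypothetical; only the two `is None` checks come from the suggestions above.

```python
from typing import List, Optional


class FSQArgsSketch:
    """Hypothetical stand-in showing inline defaults instead of a `default` helper."""

    def __init__(
        self,
        levels: List[int],
        dim: Optional[int] = None,
        num_codebooks: int = 1,
        keep_num_codebooks_dim: Optional[bool] = None,
    ):
        self.num_codebooks = num_codebooks

        # Explicit None check: an explicit False passed by the caller is preserved.
        if keep_num_codebooks_dim is None:
            keep_num_codebooks_dim = num_codebooks > 1
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        # Fall back to the derived size only when the caller gave no value.
        self.dim = len(levels) * num_codebooks if dim is None else dim
```

Checking `is None` rather than writing `value or fallback` matters here because `False` and `0` are falsy and would otherwise be silently replaced by the fallback.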

half_width = self._levels // 2
return quantized / half_width

def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
Collaborator

remove this method and move the code into codes_to_indices

half_width = self._levels // 2
return (zhat_normalized * half_width) + half_width

def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
Collaborator

same for this method
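As an illustration of inlining the two helpers, here is a minimal sketch written as free functions. The scale-and-shift step follows the body shown in the diff (`zhat_normalized * half_width + half_width`); the mixed-radix `basis` and the inverse formula are assumptions based on the usual finite scalar quantization formulation and may differ from the PR's actual code.

```python
from typing import List

import torch


def codes_to_indices(codes: torch.Tensor, levels: List[int]) -> torch.Tensor:
    """Map normalized codes (last dim = len(levels)) to flat codebook indices."""
    levels_t = torch.tensor(levels, dtype=torch.long)
    half_width = levels_t // 2
    # Assumed mixed-radix basis: cumulative product of the preceding levels.
    basis = torch.cumprod(torch.tensor([1] + list(levels[:-1]), dtype=torch.long), dim=0)
    # Inlined scale-and-shift: move each code into the range [0, level - 1].
    zhat = codes * half_width + half_width
    return (zhat * basis).sum(dim=-1).round().to(torch.long)


def indices_to_codes(indices: torch.Tensor, levels: List[int]) -> torch.Tensor:
    """Inverse mapping, with the inverse scale-and-shift inlined."""
    levels_t = torch.tensor(levels, dtype=torch.long)
    half_width = levels_t // 2
    basis = torch.cumprod(torch.tensor([1] + list(levels[:-1]), dtype=torch.long), dim=0)
    # Recover the per-dimension digit of each index.
    zhat = (indices.unsqueeze(-1) // basis) % levels_t
    return (zhat - half_width) / half_width
```

With each step written out in place, the whole conversion can be read top to bottom without jumping between one-line helpers, which is the style the review asks for.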

is_video = False
z = z.reshape(b, d, -1).permute(0, 2, 1)

assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
Collaborator

Suggested change
- assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"

self.cache_offset = 0

@staticmethod
def _cast_tuple(t: Union[Tuple[int], int], length: int = 1) -> Tuple[int]:
Collaborator

remove this method

super().__init__()
self.pad_mode = pad_mode

kernel_size = self._cast_tuple(kernel_size, 3)
Collaborator

Suggested change
- kernel_size = self._cast_tuple(kernel_size, 3)
+ if isinstance(kernel_size, int):
+     kernel_size = (kernel_size,) * 3

@annitang1997
Author

Hello, I have cleaned the code by removing small methods/functions based on your feedback. Please check it. 🤗

@annitang1997
Author

Any updates in this thread? :)
