
Conversation

@a-r-r-o-w a-r-r-o-w commented Feb 18, 2025

Adds support for Enhance-A-Video.

Paper: https://huggingface.co/papers/2502.07508
Project: https://oahzxl.github.io/Enhance_A_Video
Code: https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video

The PR needs some rework in terms of user-facing API design. I'll need some reviews to gather thoughts on how best to implement this and make it available for most, if not all, diffusers or diffusers-like video model implementations.

Currently, I've only tested with LTX Video.

import torch
from diffusers import LTXPipeline
from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug

set_verbosity_debug()

pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

num_frames = 161
latent_num_frames = (num_frames - 1) // pipe.vae_temporal_compression_ratio + 1
config = EnhanceAVideoConfig(weight=1.0, num_frames_callback=lambda: latent_num_frames, _attention_type=1)
apply_enhance_a_video(pipe.transformer, config)

prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=704,
    height=480,
    num_frames=num_frames,
    num_inference_steps=50,
    generator=torch.Generator().manual_seed(42),
).frames[0]
export_to_video(video, "output.mp4", fps=24)

The intended effect of Enhance-A-Video does not seem to be applied yet, as outputs with and without it are the same. I did some quick debugging and it seems like the enhance scores are always 1, so the hidden_states * scores returned from the attention block are unchanged. I'll need to check with the authors whether I'm doing something wrong.

cc @yangluo7 @oahzxl @kaiwang960112

@a-r-r-o-w a-r-r-o-w requested review from DN6 and yiyixuxu February 18, 2025 00:34
Comment on lines +174 to +176
scores = mean_scores.mean() * (num_frames + weight)
scores = scores.clamp(min=1)
return scores
Contributor Author

@yangluo7 @oahzxl scores here is always 1 with many different inputs that I tried. I've copied this part from the original implementation: https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video/blob/088d9e047b1738a45a253fd7cbe37fdf8526fb97/enhance_a_video/enhance.py

Am I doing something incorrect here or elsewhere? Thanks for your time!

@a-r-r-o-w I noticed you set the enhance weight as 1 in "config = EnhanceAVideoConfig(weight=1.0, num_frames_callback=lambda: latent_num_frames, _attention_type=1)." Maybe it is too small to affect the final output. In our experiments, the weight is at least 5 for LTX-Video with the setting "width=768, height=512, num_frames=121." Thanks a lot!
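For example, only the weight in the config from the PR description needs to change (reusing the pipeline and imports from that example; the value follows the suggestion above):

config = EnhanceAVideoConfig(weight=5.0, num_frames_callback=lambda: latent_num_frames, _attention_type=1)
apply_enhance_a_video(pipe.transformer, config)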

Contributor Author

Thank you! I see it taking effect now :)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The enhance weight is the only parameter introduced by our proposed method. It is affected by several factors, including num_frames and the prompts, so it needs to be tuned accordingly. We sincerely thank you for incorporating our method into diffusers, which makes our work more accessible to the community :)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

num_off_diag = num_frames * num_frames - num_frames
mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

scores = mean_scores.mean() * (num_frames + weight)
Contributor Author

Also, should the mean be taken across all dimensions? I think that might be incorrect, since each batch of data should have a different score due to different conditioning. Since we concatenate the unconditional and conditional branches and run batched inference, I believe this should be mean_scores.mean(list(range(1, mean_scores.ndim))). This gives a tensor of shape (B,), which is still compatible with the multiplication and seems more correct to me.

Thanks for your advice. We have tried this implementation before by calculating the mean score for each branch separately but found no obvious difference in the final output. As a result, we chose a more concise implementation by calculating the mean score together.
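For reference, a minimal sketch of the per-branch variant discussed above (the shapes are assumed for illustration; this is not the implementation merged in this PR):

import torch

def per_sample_enhance_scores(attn_wo_diag: torch.Tensor, num_frames: int, weight: float) -> torch.Tensor:
    # attn_wo_diag: frame-wise attention map with the diagonal zeroed out; assumed shape [B, S, F, F]
    num_off_diag = num_frames * num_frames - num_frames
    mean_scores = attn_wo_diag.sum(dim=(-2, -1)) / num_off_diag       # [B, S]
    # Reduce every non-batch dimension so the unconditional and conditional branches keep separate scores
    scores = mean_scores.mean(dim=list(range(1, mean_scores.ndim)))   # [B]
    scores = scores * (num_frames + weight)
    return scores.clamp(min=1)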

tin2tin commented Feb 19, 2025

I tested this. The first thing to note is that the model isn't released from memory when inference has finished.

I used this prompt, captioned by ChatGPT, imitating the style of the prompt above:

A young girl with curly brown hair sits on the floor, wearing a bright yellow dress. A small yellow hair clip holds back some of her curls. She is drawing on a large piece of cardboard propped up like an easel, creating a colorful scene of flowers, grass, and a sun with markers. Sunlight streams through a window behind her, casting a warm glow on the wooden floor. Her bare feet are tucked under her as she leans forward, focused on her artwork. Various markers are scattered around her. The setting appears to be a cozy living room with soft, natural lighting. The scene appears to be real-life footage.

I used the code above, but changed the weight value.

This is with no enhance:

girl_output.mp4

Weight: 2

girl_enhance2_output.mp4

Weight: 3

girl_enhance3_output.mp4

Weight: 5

girl_enhance5_output.mp4

Weight: 7

girl_enhance7_output.mp4

Weight: 8.5

girl_enhance8.5_output.mp4

Weight: 10

girl_enhance10_output.mp4

Weight: 15

girl_enhance15_output.mp4

They all have deformities, so it is kind of hard to conclude anything (is somewhere between 7 and 10 best? But the colors start to fry at 15?). Maybe that's because LTX Video is wonky from the start; personally, I find it very hard to get LTX to produce anything without severe deformities when not using the default prompt. So it is clear that with the current implementation the enhance function is doing something, but it is not clear to me whether it is enough to save poor input material. Would it clean up the deformities if run twice?

@yangluo7

Thanks for testing. Firstly, Enhance-A-Video improves the generated video's quality based on the foundation model's existing attention weights and makes moderate adjustments in the residual connection, so the final video quality still relies on the generative quality of the foundation model itself. If the original generated video quality is quite low, it is hard to produce an ideal video without improving the pre-training phase. Secondly, we can see that the video quality improves with weights between 8 and 10, which demonstrates the effectiveness of Enhance-A-Video.

Collaborator

@yiyixuxu yiyixuxu left a comment

design looks good to me
I think we should start to add tests to enforce certain styles now that we want to expand and encourage the usage of hooks

"""

weight: float = 1.0
num_frames_callback: Callable[[], int] = None
Collaborator

why does this need to be a function?

Contributor Author

@a-r-r-o-w a-r-r-o-w Feb 19, 2025

So... there's no easy way to determine this. Some models use dim=1 as the frame dimension, whereas others use dim=2 (considering a 5D tensor as the input going into the transformer). Some models don't do this at all; for example, LTX Video already flattens the FHW dimensions before the transformer forward.

The information about the number of latent frames is only available inside the transformer. Even then, it is sometimes modified by a patch embedding layer -- in the general case, we don't know for sure how to determine the number of frames being used for inference.

In the Attention blocks where we attach hooks, the tensors have shape [B, S, D], so we don't have access to this info there either.

The only source for accurately getting this information is the user :( I'm open to suggestions and to holding on to the PR for longer if we can figure out a better way to do this.
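For concreteness, this is the pattern from the example in the PR description -- the user derives the latent frame count from pipeline attributes and hands it to the hook through the callback:

num_frames = 161
latent_num_frames = (num_frames - 1) // pipe.vae_temporal_compression_ratio + 1  # LTX VAE temporal compression
config = EnhanceAVideoConfig(weight=1.0, num_frames_callback=lambda: latent_num_frames, _attention_type=1)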

Collaborator

oh that's fine,
I was just wondering why it is a function, not a constant


def new_forward(self, module, *args, **kwargs):
# Here, query and key have two shapes (considering the general diffusers-style model implementation):
# 1. [batch_size, attention_heads, latents_sequence_length, head_dim]
Collaborator

can we start to automatically test and enforce this (make sure it's the case for all new models we implement)?
OmniGen almost did not follow this, and it was not always easy to spot such things

Contributor Author

Sounds good. I can add some tests soon to enforce this on any model that is added


def reshape_for_framewise_attention(tensor: torch.Tensor) -> torch.Tensor:
# This code assumes tensor is [B, N, S, C]. This should be true for most diffusers-style implementations.
# [B, N, S, C] -> [B, N, F, S, C] -> [B, S, N, F, C] -> [B * S, N, F, C]
Collaborator

same here
can we start to enforce this?
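For context, a rough sketch of the reshape described in the comment above (the shapes and the explicit num_frames argument are assumptions for illustration; in the PR the frame count comes from the config callback):

import torch

def reshape_for_framewise_attention_sketch(tensor: torch.Tensor, num_frames: int) -> torch.Tensor:
    # Assumes input layout [B, N, S_total, C] with S_total = num_frames * spatial tokens
    batch_size, num_heads, seq_len, channels = tensor.shape
    spatial = seq_len // num_frames
    tensor = tensor.reshape(batch_size, num_heads, num_frames, spatial, channels)  # [B, N, F, S, C]
    tensor = tensor.permute(0, 3, 1, 2, 4)                                         # [B, S, N, F, C]
    return tensor.flatten(0, 1)                                                    # [B * S, N, F, C]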


weight: Union[float, Dict[str, float]] = 1.0
num_frames_callback: Callable[[], int] = None
_attention_type: _AttentionType = _AttentionType.SELF
Collaborator

@yiyixuxu yiyixuxu Feb 19, 2025

seems like only this is a config
weight and num_frames are more like runtime arguments, no? currently how do we update these for each generation?

Contributor Author

We currently only support dynamically updating these values if the user first removes all hooks by calling remove_enhance_a_video and then calls apply_enhance_a_video again (sketched below). It's, uh, not really ideal, but it is a lightweight operation so we can get away with it.

Alternatively, to update dynamically, do you think we should do this:

  • when the user calls apply_enhance_a_video, we return them some kind of handle object that has knowledge about the hooks
  • they can call set_weight and set_frames methods on it
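For reference, the current remove-and-re-apply path looks roughly like this (reusing names from the example in the PR description; assuming remove_enhance_a_video is exported next to apply_enhance_a_video):

from diffusers.hooks import apply_enhance_a_video, remove_enhance_a_video, EnhanceAVideoConfig

remove_enhance_a_video(pipe.transformer)  # drop the existing hooks
config = EnhanceAVideoConfig(weight=5.0, num_frames_callback=lambda: latent_num_frames, _attention_type=1)
apply_enhance_a_video(pipe.transformer, config)  # re-apply with the new runtime values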

Contributor Author

@a-r-r-o-w a-r-r-o-w Feb 19, 2025

Also, perhaps we don't need the _attention_type argument. I can define a simple dictionary in the _common.py file that categorizes each attention processor into the three groups -- I think this is good info to have for some other methods that we could integrate soon

Collaborator

remove_enhance_a_video is fine, I think!
we can wait for a couple more use cases to decide how to support this
(I don't like the set_weight and set_frames methods because they're specific to each config name; I think we need something more generic)

@a-r-r-o-w
Contributor Author

@tin2tin Our LTX implementation is missing a few of the latest features from the original repository that improve generation quality. This will be improved soon, once I find some time to work on it. I've added support for setting a weight factor per block (you can specify a dictionary mapping regex patterns to weight values), so you can play around a bit and see which layers are best suited for applying the method -- from my testing, applying it to blocks 5-20 seems to work best and does not modify predictions too much.
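A hypothetical sketch of that per-block weighting (the module-name pattern is an assumption and should be adjusted to the transformer's actual block names):

# Weight only blocks 5-20; the regex keys shown here are illustrative
block_weights = {rf"transformer_blocks\.{i}\.": 5.0 for i in range(5, 21)}
config = EnhanceAVideoConfig(weight=block_weights, num_frames_callback=lambda: latent_num_frames, _attention_type=1)
apply_enhance_a_video(pipe.transformer, config)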

@a-r-r-o-w
Contributor Author

@yangluo7 Would you be able to give a final review as well, as a correctness check? The implementation will not change much after my latest commit, apart from docs/tests, so it is more or less finalized. Thanks!

Comment on lines +118 to +120
# 1. [batch_size, attention_heads, latents_sequence_length, head_dim]
# 2. [batch_size, attention_heads, latents_sequence_length + encoder_sequence_length, head_dim]
# 3. [batch_size, attention_heads, encoder_sequence_length + latents_sequence_length, head_dim]
Contributor Author

Suggested change
- # 1. [batch_size, attention_heads, latents_sequence_length, head_dim]
- # 2. [batch_size, attention_heads, latents_sequence_length + encoder_sequence_length, head_dim]
- # 3. [batch_size, attention_heads, encoder_sequence_length + latents_sequence_length, head_dim]
+ # 1. [batch_size, latents_sequence_length, embedding_dim]
+ # 2. [batch_size, latents_sequence_length + encoder_sequence_length, embedding_dim]
+ # 3. [batch_size, encoder_sequence_length + latents_sequence_length, embedding_dim]

return module

def new_forward(self, module, *args, **kwargs):
# Here, query and key have two shapes (considering the general diffusers-style model implementation):
Contributor Author

Suggested change
- # Here, query and key have two shapes (considering the general diffusers-style model implementation):
+ # Here, hidden_states could have three shapes (considering the general diffusers-style model implementation):

hook_registry.register_hook(hook, _ENHANCE_A_VIDEO)


def remove_enhance_a_video(module: torch.nn.Module) -> None:
Collaborator

do we have a method to call when we want to remove all the hooks on a model?

Contributor Author

Not yet :(

We can only remove hooks one at a time, and only if the user knows the hook name (it can be found by printing the registry):

def remove_hook(self, name: str, recurse: bool = True) -> None:

I can add a method that allows removing all hooks if you'd like, LMK
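For reference, the one-at-a-time path looks roughly like this (assuming HookRegistry is importable from diffusers.hooks and the registry is reachable via check_if_exists_or_initialize; the hook name string below is an assumption):

from diffusers.hooks import HookRegistry

registry = HookRegistry.check_if_exists_or_initialize(pipe.transformer)
print(registry)                           # lists the registered hook names
registry.remove_hook("enhance_a_video")   # name shown here is illustrative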

Collaborator

Hmm, might be good to add methods such as enable_hook, disable_hook, disable_all_hooks, etc. to ModelMixin.

Contributor Author

@DN6 Sounds good. I think we should add those first, so I'll hold off on merging here and open a PR for that first.

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Mar 22, 2025
@yiyixuxu yiyixuxu removed the stale Issues that haven't received updates label Mar 24, 2025
@a-r-r-o-w a-r-r-o-w closed this Jul 25, 2025