Enhance-A-Video #10815
Changes from 4 commits
`_common.py` (new file):

```python
from ..models.attention_processor import Attention, MochiAttention


_ATTENTION_CLASSES = (Attention, MochiAttention)


_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
```
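These identifier tuples are not consumed anywhere in this diff; presumably other hooks use them to locate transformer block lists by attribute name. A hedged illustration of how such matching might look (`_is_spatial_attention_block` is a hypothetical helper, not part of the PR):

```python
def _is_spatial_attention_block(module_name: str) -> bool:
    # Match module paths like "transformer_blocks.0.attn1" by their leading attribute name.
    return any(module_name.startswith(identifier) for identifier in _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS)
```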
The second file (new, 199 lines) implements the hooks:

```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional

import torch
import torch.overrides

from ..utils import get_logger
from ._common import _ATTENTION_CLASSES
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)


_ENHANCE_A_VIDEO = "enhance_a_video"


class _AttentionType(Enum):
    SELF = 0
    JOINT___LATENTS_FIRST = 1
    JOINT___LATENTS_LAST = 2


@dataclass
class EnhanceAVideoConfig:
    r"""
    Configuration for [Enhance A Video](https://huggingface.co/papers/2502.07508).
    """

    weight: float = 1.0
    num_frames_callback: Optional[Callable[[], int]] = None
    _attention_type: _AttentionType = _AttentionType.SELF
```
> **Reviewer:** seems only this is a config
>
> **Author:** We currently only support dynamically updating these values if the user first removes all hooks by calling `remove_enhance_a_video`. Alternatively, do you think we should support updating them dynamically in place?
>
> **Author:** Also, perhaps we don't need the dedicated removal helper?
>
> **Reviewer:** `remove_enhance_a_video` is fine, I think!
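For illustration, a minimal sketch of the remove-then-reapply flow described above (assuming `transformer` is a diffusers video transformer and the symbols are imported from the module added in this PR; `latent_num_frames` is a placeholder value):

```python
# Apply with an initial weight; the latent frame count must be supplied by the caller.
latent_num_frames = 16  # hypothetical value for this sketch
config = EnhanceAVideoConfig(weight=4.0, num_frames_callback=lambda: latent_num_frames)
apply_enhance_a_video(transformer, config)

# To change `weight` later, remove the hooks and re-apply with a new config.
remove_enhance_a_video(transformer)
apply_enhance_a_video(
    transformer, EnhanceAVideoConfig(weight=6.0, num_frames_callback=lambda: latent_num_frames)
)
```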
```python
class EnhanceAVideoAttentionState:
    def __init__(self) -> None:
        self.query: torch.Tensor = None
        self.key: torch.Tensor = None
        self.latents_sequence_length: int = None

    def reset(self) -> None:
        self.query = None
        self.key = None
        self.latents_sequence_length = None

    def __repr__(self):
        # Report the attributes this state object actually holds.
        return (
            f"EnhanceAVideoAttentionState(query={self.query}, key={self.key}, "
            f"latents_sequence_length={self.latents_sequence_length})"
        )

class EnhanceAVideoCaptureSDPAInputsFunctionMode(torch.overrides.TorchFunctionMode):
    def __init__(self, query_key_save_callback: Callable[[torch.Tensor, torch.Tensor], None]) -> None:
        super().__init__()
        self.query_key_save_callback = query_key_save_callback

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is torch.nn.functional.scaled_dot_product_attention:
            # Use explicit presence checks: `kwargs.get("query") or args[0]` would evaluate
            # a tensor's truth value and raise for multi-element tensors.
            query = kwargs["query"] if "query" in kwargs else args[0]
            key = kwargs["key"] if "key" in kwargs else args[1]
            self.query_key_save_callback(query, key)
        return func(*args, **kwargs)
```
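The capture class above relies on `torch.overrides.TorchFunctionMode`, which intercepts torch API calls made inside its context. A minimal self-contained sketch of the same mechanism (not part of the diff):

```python
import torch
import torch.nn.functional as F

captured = {}

class CaptureSDPA(torch.overrides.TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.scaled_dot_product_attention:
            # Save the positional query/key before delegating to the real implementation.
            captured["query"], captured["key"] = args[0], args[1]
        return func(*args, **kwargs)

q, k, v = (torch.randn(1, 4, 16, 8) for _ in range(3))
with CaptureSDPA():
    out = F.scaled_dot_product_attention(q, k, v)
assert captured["query"] is q and captured["key"] is k
```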
```python
class EnhanceAVideoSDPAHook(ModelHook):
    _is_stateful = True

    def __init__(self, weight: float, num_frames_callback: Callable[[], int], _attention_type: _AttentionType) -> None:
        self.weight = weight
        self.num_frames_callback = num_frames_callback
        self._attention_type = _attention_type

    def initialize_hook(self, module):
        self.state = EnhanceAVideoAttentionState()
        return module

    def new_forward(self, module, *args, **kwargs):
        # Here, query and key have one of the following shapes (considering the general
        # diffusers-style model implementation):
        #   1. [batch_size, attention_heads, latents_sequence_length, head_dim]
        #   2. [batch_size, attention_heads, latents_sequence_length + encoder_sequence_length, head_dim]
        #   3. [batch_size, attention_heads, encoder_sequence_length + latents_sequence_length, head_dim]
        kwargs_hidden_states = kwargs.get("hidden_states", None)
        hidden_states = kwargs_hidden_states if kwargs_hidden_states is not None else args[0]
        # The attention block receives `hidden_states` as [batch_size, seq_len, dim], so the
        # latents sequence length is dim 1 (dim 2 would be the channel dimension).
        self.state.latents_sequence_length = hidden_states.size(1)

        # Capture query and key tensors to compute Enhance-A-Video scores
        with EnhanceAVideoCaptureSDPAInputsFunctionMode(self._query_key_capture_callback):
            return self.fn_ref.original_forward(*args, **kwargs)
```

> **Reviewer** (on the shape comments): can we start to automatically test and enforce this (make sure it is the case for all new models we implement)?
>
> **Author:** Sounds good. I can add some tests soon to enforce this on any model that is added.
```python
    def post_forward(self, module, output):
        # For diffusers models, or ones implemented in a similar way, the block returns either:
        #   1. a single output: `hidden_states`
        #   2. a tuple of outputs: `(hidden_states, encoder_hidden_states)`
        # We need to handle both cases when applying the Enhance-A-Video scores.
        hidden_states = output[0] if isinstance(output, tuple) else output

        def reshape_for_framewise_attention(tensor: torch.Tensor) -> torch.Tensor:
            # This code assumes the tensor is [B, N, S, C], which should be true for most
            # diffusers-style implementations (S below is the number of tokens per frame).
            # [B, N, F * S, C] -> [B, N, F, S, C] -> [B, S, N, F, C] -> [B * S, N, F, C]
            return tensor.unflatten(2, (num_frames, -1)).permute(0, 3, 1, 2, 4).flatten(0, 1)

        # Handle reshaping of query and key tensors
        query, key = self.state.query, self.state.key
        if self._attention_type == _AttentionType.SELF:
            pass
        elif self._attention_type == _AttentionType.JOINT___LATENTS_FIRST:
            query = query[:, :, : self.state.latents_sequence_length]
            key = key[:, :, : self.state.latents_sequence_length]
        elif self._attention_type == _AttentionType.JOINT___LATENTS_LAST:
            query = query[:, :, -self.state.latents_sequence_length :]
            key = key[:, :, -self.state.latents_sequence_length :]

        num_frames = self.num_frames_callback()
        query = reshape_for_framewise_attention(query)
        key = reshape_for_framewise_attention(key)
        scores = enhance_a_video_score(query, key, num_frames, self.weight)
        logger.debug(f"Applying Enhance-A-Video scores: {scores}")
        hidden_states = hidden_states * scores

        return (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states

    def reset_state(self, module):
        self.state.reset()
        return module

    def _query_key_capture_callback(self, query: torch.Tensor, key: torch.Tensor) -> None:
        self.state.query = query
        self.state.key = key
```

> **Reviewer** (on the `[B, N, S, C]` assumption in `reshape_for_framewise_attention`): same here.
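As a sanity check on the reshape comment above, a small standalone walk-through with hypothetical sizes:

```python
import torch

# B=2 batch, N=4 heads, F=8 frames, S=16 tokens per frame, C=64 head dim.
B, N, F, S, C = 2, 4, 8, 16, 64
x = torch.randn(B, N, F * S, C)                                   # [B, N, F * S, C]
y = x.unflatten(2, (F, -1)).permute(0, 3, 1, 2, 4).flatten(0, 1)
print(y.shape)  # torch.Size([32, 4, 8, 64]) == [B * S, N, F, C]
```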
```python
def enhance_a_video_score(
    query: torch.Tensor, key: torch.Tensor, num_frames: int, weight: float = 1.0
) -> torch.Tensor:
    head_dim = query.size(-1)
    scale = 1 / (head_dim**0.5)
    query = query * scale

    attn_temp = query @ key.transpose(-2, -1)
    attn_temp = attn_temp.float()
    attn_temp = attn_temp.softmax(dim=-1)

    # Reshape to [batch_size * num_tokens, num_frames, num_frames]
    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

    # Create a mask for diagonal elements
    diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
    diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.size(0), -1, -1)

    # Zero out diagonal elements
    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

    # Calculate the mean of each token's attention matrix.
    # The number of off-diagonal elements per matrix is n*n - n.
    num_off_diag = num_frames * num_frames - num_frames
    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

    scores = mean_scores.mean() * (num_frames + weight)
    scores = scores.clamp(min=1)
    return scores
```

> **Reviewer:** Should the mean be taken across all dimensions? I think it might be incorrect, since each batch entry should have a different score due to different conditioning. Since we concatenate the unconditional and conditional branches and run batched inference, I believe this should be computed per sample.
>
> **Paper author:** Thanks for your advice. We tried this implementation before, calculating the mean score for each branch separately, but found no obvious difference in the final output. As a result, we chose the more concise implementation that calculates the mean score jointly.
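For reference, a hedged sketch of the per-sample variant discussed above (not what this PR implements), assuming the leading dimension after the framewise reshape still groups by batch entry, e.g. the two classifier-free-guidance branches:

```python
def enhance_a_video_score_per_sample(
    query: torch.Tensor, key: torch.Tensor, num_frames: int, weight: float = 1.0, batch_size: int = 2
) -> torch.Tensor:
    # query/key: [batch_size * tokens, heads, num_frames, head_dim], as produced by
    # reshape_for_framewise_attention; batch_size=2 covers the unconditional and
    # conditional branches of classifier-free guidance.
    scale = query.size(-1) ** -0.5
    attn = (query * scale) @ key.transpose(-2, -1)                # [B * S, N, F, F]
    attn = attn.float().softmax(dim=-1)
    attn = attn.reshape(batch_size, -1, num_frames, num_frames)   # [B, S * N, F, F]

    diag_mask = torch.eye(num_frames, dtype=torch.bool, device=attn.device)
    attn_wo_diag = attn.masked_fill(diag_mask, 0)
    num_off_diag = num_frames * (num_frames - 1)
    mean_scores = attn_wo_diag.sum(dim=(-2, -1)) / num_off_diag   # [B, S * N]

    # One score per batch entry instead of a single global scalar; a caller would
    # broadcast these over hidden_states, e.g. via scores.view(-1, 1, 1).
    scores = mean_scores.mean(dim=-1) * (num_frames + weight)     # [B]
    return scores.clamp(min=1)
```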
> **Author** (on the score computation): @yangluo7 @oahzxl Am I doing something incorrect here or elsewhere? Thanks for your time!
>
> **Paper author:** @a-r-r-o-w I noticed you set the enhance weight to 1 in `config = EnhanceAVideoConfig(weight=1.0, num_frames_callback=lambda: latent_num_frames, _attention_type=1)`. Maybe it is too small to affect the final output. In our experiments, the weight is at least 5 for LTX-Video with the settings `width=768, height=512, num_frames=121`. Thanks a lot!
>
> **Author:** Thank you! I see it taking effect now :)
>
> **Paper author:** The enhance weight is the only parameter introduced by our proposed method. It is affected by several factors, including `num_frames` and the prompt, so it needs to be tuned based on them. We sincerely thank you for incorporating our method into diffusers, which makes our work more accessible to the community :)
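Putting the numbers from this thread together, a hedged end-to-end sketch (the pipeline attribute names and the `_attention_type` choice are assumptions, not part of this diff):

```python
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")

num_frames = 121
# Assumed attribute: derive the latent frame count from the VAE's temporal compression.
latent_num_frames = (num_frames - 1) // pipe.vae_temporal_compression_ratio + 1

config = EnhanceAVideoConfig(
    weight=5.0,  # the paper authors suggest at least 5 for LTX-Video at 768x512, 121 frames
    num_frames_callback=lambda: latent_num_frames,
    _attention_type=_AttentionType.JOINT___LATENTS_FIRST,  # matches `_attention_type=1` from the thread
)
apply_enhance_a_video(pipe.transformer, config)

video = pipe(prompt="a cat playing piano", width=768, height=512, num_frames=num_frames).frames[0]
```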
```python
def apply_enhance_a_video(module: torch.nn.Module, config: EnhanceAVideoConfig) -> None:
    for name, submodule in module.named_modules():
        is_cross_attention = getattr(submodule, "is_cross_attention", False)
        if not isinstance(submodule, _ATTENTION_CLASSES) or is_cross_attention:
            continue
        logger.debug(f"Applying Enhance-A-Video to layer '{name}'")
        hook_registry = HookRegistry.check_if_exists_or_initialize(submodule)
        hook = EnhanceAVideoSDPAHook(
            weight=config.weight,
            num_frames_callback=config.num_frames_callback,
            _attention_type=config._attention_type,
        )
        hook_registry.register_hook(hook, _ENHANCE_A_VIDEO)
```
```python
def remove_enhance_a_video(module: torch.nn.Module) -> None:
    for name, submodule in module.named_modules():
        if not hasattr(submodule, "_diffusers_hook"):
            continue
        hook_registry = submodule._diffusers_hook
        hook_registry.remove_hook(_ENHANCE_A_VIDEO, recurse=False)
        logger.debug(f"Removed Enhance-A-Video from layer '{name}'")
```

> **Reviewer:** do we have a method to call when we want to remove all the model hooks on a model?
>
> **Author:** Not yet :( We can only remove hooks one at a time (`remove_hook` in `src/diffusers/hooks/hooks.py`, line 179 at f8b54cf). I can add a method that allows removing all hooks if you'd like, LMK.
>
> **Reviewer (@DN6):** Hmm, might be good to add methods such as `remove_all_hooks`.
>
> **Author:** @DN6 Sounds good. I think we should add those first, so I will hold off on merging here and open a PR for that first.
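A hedged sketch of what such a `remove_all_hooks`-style helper could look like; treating `registry.hooks` as a name-to-hook mapping is an assumption about `HookRegistry` internals:

```python
def remove_all_hooks(module: torch.nn.Module) -> None:
    # Walk every submodule and clear its hook registry, if one was attached.
    for _, submodule in module.named_modules():
        registry = getattr(submodule, "_diffusers_hook", None)
        if registry is None:
            continue
        for hook_name in list(registry.hooks.keys()):  # assumed internal dict of name -> hook
            registry.remove_hook(hook_name, recurse=False)
```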
> **Reviewer** (on `num_frames_callback`): why does this need to be a function?
>
> **Author:** So... there's no easy way to determine this. Some models use dim=1 as the frame dimension, whereas some use dim=2 (consider a 5D tensor as the input going into the transformer). Some models don't do this at all; for example, LTX-Video already flattens the frame-height-width dimensions before the transformer forward. The information about the number of latent frames is only available in the model transformer, and even then it is sometimes modified by a patch-embedding layer, so in the general case we don't know for sure how many frames are being used for inference. In the attention blocks where we attach the hooks, the tensors are `[B, S, D]`, so we don't have access to this info there either. The only source for accurately getting this information is the user :( I'm open to suggestions and to holding the PR for longer if we can figure out a better way to do this.
>
> **Reviewer:** oh that's fine, I was just wondering why it is a function and not a constant.
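Following up on that exchange, a small sketch of why a callable rather than a constant is convenient: the value can change between generations without re-registering the hooks. The closure pattern here is illustrative, not part of the diff:

```python
# A mutable closure lets the user vary the frame count per generation.
state = {"latent_num_frames": 16}
config = EnhanceAVideoConfig(weight=5.0, num_frames_callback=lambda: state["latent_num_frames"])
apply_enhance_a_video(transformer, config)

# Later, before generating a video with a different length:
state["latent_num_frames"] = 25  # the hooks pick this up on the next forward pass
```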