1 change: 1 addition & 0 deletions src/diffusers/hooks/__init__.py
@@ -2,6 +2,7 @@


if is_torch_available():
from .enhance_a_video import EnhanceAVideoConfig, apply_enhance_a_video
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
8 changes: 8 additions & 0 deletions src/diffusers/hooks/_common.py
@@ -0,0 +1,8 @@
from ..models.attention_processor import Attention, MochiAttention


_ATTENTION_CLASSES = (Attention, MochiAttention)

_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
191 changes: 191 additions & 0 deletions src/diffusers/hooks/enhance_a_video.py
@@ -0,0 +1,191 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from enum import Enum
from typing import Callable

import torch
import torch.overrides

from ..utils import get_logger
from ._common import _ATTENTION_CLASSES
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)


_ENHANCE_A_VIDEO = "enhance_a_video"
_ENHANCE_A_VIDEO_SDPA = "enhance_a_video_sdpa"


class _AttentionType(Enum):
SELF = 0
JOINT___LATENTS_FIRST = 1
JOINT___LATENTS_LAST = 2


@dataclass
class EnhanceAVideoConfig:
r"""
Configuration for [Enhance A Video](https://huggingface.co/papers/2502.07508).
"""

weight: float = 1.0
num_frames_callback: Callable[[], int] = None
Collaborator:

Why does this need to be a function?

Contributor Author (@a-r-r-o-w, Feb 19, 2025):

So... there's no easy way to determine this. Some models use dim=1 as the frame dimension, whereas others use dim=2 (consider a 5D tensor going into the transformer). Some models don't do this at all; LTX Video, for example, already flattens the FHW dimensions before the transformer forward.

The number of latent frames is only known inside the model transformer. Even there it is sometimes modified by a patch embedding layer -- in the general case we don't know for sure how many frames are being used for inference.

In the Attention block where we attach hooks, the tensors have shape [B, S, D], so we don't have access to this information either.

The only source that can accurately provide this information is the user :( I'm open to suggestions and to holding the PR longer if we can figure out a better way to do this.

Collaborator:

Oh, that's fine. I was just wondering why it is a function and not a constant.
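As a minimal sketch of what satisfying this callback can look like in practice (the latent-frame formula and compression ratio below are assumptions for illustration, not values taken from this PR):

```python
# Hedged sketch: the caller computes the latent frame count and hands it to the hook via a closure.
from diffusers.hooks import EnhanceAVideoConfig

num_frames = 121                    # pixel-space frames requested from the pipeline
temporal_compression_ratio = 8      # assumed VAE temporal compression; model dependent
latent_num_frames = (num_frames - 1) // temporal_compression_ratio + 1

config = EnhanceAVideoConfig(
    weight=5.0,                     # the paper authors suggest >= 5 for LTX-Video (see the thread further down)
    num_frames_callback=lambda: latent_num_frames,
)
```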

_attention_type: _AttentionType = _AttentionType.SELF
Collaborator (@yiyixuxu, Feb 19, 2025):

It seems only this field is really a config. weight and num_frames are more like runtime arguments, no? Currently, how do we update these for each generation?

Contributor Author:

We currently only support dynamically updating these values if the user first removes all hooks by calling remove_enhance_a_video and then calls apply_enhance_a_video again. It's, uh, not really ideal, but it is a lightweight operation, so we can get away with it.

Alternatively, to update dynamically, do you think we should do this:

  • when the user calls apply_enhance_a_video, we return some kind of handle object that has knowledge of the hooks
  • they can then call set_weight and set_frames methods on it

Contributor Author (@a-r-r-o-w, Feb 19, 2025):

Also, perhaps we don't need the _attention_type argument. I can define a simple dictionary in the _common.py file that categorizes each attention processor into the three groups -- I think this is useful information to have for some other methods we could integrate soon.

Collaborator:

remove_enhance_a_video is fine, I think! We can wait for a couple more use cases before deciding how to support this. (I don't like the set_weight and set_frames methods because they are specific to each config name; I think we need something more generic.)
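A rough sketch of the remove-then-reapply flow mentioned above, assuming remove_enhance_a_video takes the same module argument as apply_enhance_a_video (its definition is not part of this diff):

```python
# Hedged sketch: update weight / frame count between generations by re-attaching the hooks.
from diffusers.hooks import EnhanceAVideoConfig, apply_enhance_a_video


def reconfigure_enhance_a_video(transformer, weight: float, latent_num_frames: int) -> None:
    remove_enhance_a_video(transformer)  # assumed API, referenced in the discussion above
    config = EnhanceAVideoConfig(weight=weight, num_frames_callback=lambda: latent_num_frames)
    apply_enhance_a_video(transformer, config)
```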



class EnhanceAVideoAttentionState:
def __init__(self) -> None:
self.query: torch.Tensor = None
self.key: torch.Tensor = None
self.latents_sequence_length: int = None

def reset(self) -> None:
self.query = None
self.key = None
self.latents_sequence_length = None

def __repr__(self):
return f"EnhanceAVideoAttentionState(query={self.query}, key={self.key}, latents_sequence_length={self.latents_sequence_length})"


class EnhanceAVideoCaptureSDPAInputsFunctionMode(torch.overrides.TorchFunctionMode):
def __init__(self, query_key_save_callback: Callable[[torch.Tensor, torch.Tensor], None]) -> None:
super().__init__()

self.query_key_save_callback = query_key_save_callback

def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.nn.functional.scaled_dot_product_attention:
query = kwargs.get("query", None) or args[0]
key = kwargs.get("key", None) or args[1]
self.query_key_save_callback(query, key)
return func(*args, **kwargs)
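For readers unfamiliar with TorchFunctionMode, here is a self-contained sketch of the interception pattern the class above relies on; the tensor shapes are illustrative only:

```python
# Standalone sketch: a TorchFunctionMode that observes calls to scaled_dot_product_attention.
import torch
import torch.overrides

captured = {}


class _CaptureSDPAInputs(torch.overrides.TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.nn.functional.scaled_dot_product_attention:
            # SDPA is typically called positionally: (query, key, value, ...)
            captured["query"], captured["key"] = args[0], args[1]
        return func(*args, **kwargs)


query = torch.randn(1, 8, 16, 64)  # [batch, heads, seq_len, head_dim]
key = torch.randn(1, 8, 16, 64)
value = torch.randn(1, 8, 16, 64)

with _CaptureSDPAInputs():
    out = torch.nn.functional.scaled_dot_product_attention(query, key, value)

assert captured["query"] is query and captured["key"] is key
```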


class EnhanceAVideoSDPAHook(ModelHook):
_is_stateful = True

def __init__(self, weight: float, num_frames_callback: Callable[[], int], _attention_type: _AttentionType) -> None:
self.weight = weight
self.num_frames_callback = num_frames_callback
self._attention_type = _attention_type

def initialize_hook(self, module):
self.state = EnhanceAVideoAttentionState()
return module

def new_forward(self, module, *args, **kwargs):
# Here, query and key have two shapes (considering the general diffusers-style model implementation):
Contributor Author:

Suggested change:
- # Here, query and key have two shapes (considering the general diffusers-style model implementation):
+ # Here, hidden_states could have three shapes (considering the general diffusers-style model implementation):

# 1. [batch_size, attention_heads, latents_sequence_length, head_dim]
Collaborator:

Can we start to automatically test and enforce this (i.e., make sure it holds for all new models we implement)? OmniGen almost did not follow it, and such things are not always easy to spot.

Contributor Author:

Sounds good. I can add some tests soon to enforce this on any model that is added.

# 2. [batch_size, attention_heads, latents_sequence_length + encoder_sequence_length, head_dim]
# 3. [batch_size, attention_heads, encoder_sequence_length + latents_sequence_length, head_dim]
Contributor Author (comment on lines +118 to +120):

Suggested change:
- # 1. [batch_size, attention_heads, latents_sequence_length, head_dim]
- # 2. [batch_size, attention_heads, latents_sequence_length + encoder_sequence_length, head_dim]
- # 3. [batch_size, attention_heads, encoder_sequence_length + latents_sequence_length, head_dim]
+ # 1. [batch_size, latents_sequence_length, embedding_dim]
+ # 2. [batch_size, latents_sequence_length + encoder_sequence_length, embedding_dim]
+ # 3. [batch_size, encoder_sequence_length + latents_sequence_length, embedding_dim]

kwargs_hidden_states = kwargs.get("hidden_states", None)
hidden_states = kwargs_hidden_states if kwargs_hidden_states is not None else args[0]
self.state.latents_sequence_length = hidden_states.size(2)

# Capture query and key tensors to compute EnhanceAVideo scores
with EnhanceAVideoCaptureSDPAInputsFunctionMode(self._query_key_capture_callback):
return self.fn_ref.original_forward(*args, **kwargs)

def post_forward(self, module, output):
# For diffusers models, or ones that are implemented similar to our design, we either return:
# 1. A single output: `hidden_states`
# 2. A tuple of outputs: `(hidden_states, encoder_hidden_states)`.
# We need to handle both cases of applying EnhanceAVideo scores.
hidden_states = output[0] if isinstance(output, tuple) else output

def reshape_for_framewise_attention(tensor: torch.Tensor) -> torch.Tensor:
# This code assumes tensor is [B, N, S, C]. This should be true for most diffusers-style implementations.
# [B, N, S, C] -> [B, N, F, S, C] -> [B, S, N, F, C] -> [B * S, N, F, C]
Collaborator:

Same here -- can we start to enforce this?

return tensor.unflatten(2, (num_frames, -1)).permute(0, 3, 1, 2, 4).flatten(0, 1)

# Handle reshaping of query and key tensors
query, key = self.state.query, self.state.key
if self._attention_type == _AttentionType.SELF:
pass
elif self._attention_type == _AttentionType.JOINT___LATENTS_FIRST:
query = query[:, :, : self.state.latents_sequence_length]
key = key[:, :, : self.state.latents_sequence_length]
elif self._attention_type == _AttentionType.JOINT___LATENTS_LAST:
query = query[:, :, -self.state.latents_sequence_length :]
key = key[:, :, -self.state.latents_sequence_length :]

num_frames = self.num_frames_callback()
query = reshape_for_framewise_attention(query)
key = reshape_for_framewise_attention(key)
scores = enhance_a_video_score(query, key, num_frames, self.weight)
print("Applying scores:", scores)
hidden_states = hidden_states * scores

return (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states

def reset_state(self, module):
self.state.reset()
return module

def _query_key_capture_callback(self, query: torch.Tensor, key: torch.Tensor) -> None:
self.state.query = query
self.state.key = key


def enhance_a_video_score(
query: torch.Tensor, key: torch.Tensor, num_frames: int, weight: float = 1.0
) -> torch.Tensor:
head_dim = query.size(-1)
scale = 1 / (head_dim**0.5)
query = query * scale

attn_temp = query @ key.transpose(-2, -1)
attn_temp = attn_temp.float()
attn_temp = attn_temp.softmax(dim=-1)

# Reshape to [batch_size * num_tokens, num_frames, num_frames]
attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

# Create a mask for diagonal elements
diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.size(0), -1, -1)

# Zero out diagonal elements
attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

# Calculate mean for each token's attention matrix
# Number of off-diagonal elements per matrix is n*n - n
num_off_diag = num_frames * num_frames - num_frames
mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

scores = mean_scores.mean() * (num_frames + weight)
Contributor Author:

Also, should the mean be taken across all dimensions? I think that might be incorrect, since each batch element should get a different score due to different conditioning. Since we concatenate the unconditional and conditional branches and run batched inference, I believe this should be mean_scores.mean(list(range(1, mean_scores.ndim))). That gives a tensor of shape (B,), which is still compatible with the multiplication and seems more correct to me.

Reply:

Thanks for your advice. We tried this implementation before, calculating the mean score for each branch separately, but found no obvious difference in the final output. As a result, we chose the more concise implementation of calculating the mean score jointly.

scores = scores.clamp(min=1)
return scores
Contributor Author (comment on lines +195 to +197):

@yangluo7 @oahzxl scores here is always 1 with many different inputs that I tried. I've copied this part from the original implementation: https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video/blob/088d9e047b1738a45a253fd7cbe37fdf8526fb97/enhance_a_video/enhance.py

Am I doing something incorrect here or elsewhere? Thanks for your time!

Reply:

@a-r-r-o-w I noticed you set the enhance weight to 1 in "config = EnhanceAVideoConfig(weight=1.0, num_frames_callback=lambda: latent_num_frames, _attention_type=1)". Maybe it is too small to affect the final output. In our experiments, the weight is at least 5 for LTX-Video with the setting "width=768, height=512, num_frames=121". Thanks a lot!

Contributor Author:

Thank you! I see it taking effect now :)

Reply:

The enhance weight is the only parameter introduced in our proposed method. It is affected by several factors, including num_frames and prompts, so it needs to be tuned accordingly. We sincerely thank you for incorporating our method into diffusers, which makes our work more accessible to the community :)
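As a small sanity check of enhance_a_video_score, here is a toy call on random tensors with the shape that reshape_for_framewise_attention produces ([batch * spatial_tokens, heads, frames, head_dim]); the values are random, so only the shape and clamping behaviour are being demonstrated, and the import path simply follows the module added in this PR:

```python
# Toy check of enhance_a_video_score on random inputs with assumed shapes.
import torch

from diffusers.hooks.enhance_a_video import enhance_a_video_score

batch_times_tokens, num_heads, num_frames, head_dim = 4, 2, 8, 16
query = torch.randn(batch_times_tokens, num_heads, num_frames, head_dim)
key = torch.randn(batch_times_tokens, num_heads, num_frames, head_dim)

score = enhance_a_video_score(query, key, num_frames=num_frames, weight=5.0)
print(score)  # scalar tensor, clamped to >= 1, later multiplied into hidden_states in post_forward

# The per-sample variant discussed above would replace mean_scores.mean() inside the
# function with mean_scores.mean(list(range(1, mean_scores.ndim))), yielding shape (B,).
```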



def apply_enhance_a_video(module: torch.nn.Module, config: EnhanceAVideoConfig) -> None:
for name, submodule in module.named_modules():
is_cross_attention = getattr(submodule, "is_cross_attention", False)
if not isinstance(submodule, _ATTENTION_CLASSES) or is_cross_attention:
continue
logger.debug(f"Applying Enhance-A-Video to layer '{name}'")
hook_registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = EnhanceAVideoSDPAHook(
weight=config.weight,
num_frames_callback=config.num_frames_callback,
_attention_type=config._attention_type,
)
hook_registry.register_hook(hook, _ENHANCE_A_VIDEO_SDPA)
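Putting it together, a hedged end-to-end sketch based on the settings mentioned in the review above (a weight of at least 5 for LTX-Video at width=768, height=512, num_frames=121); the vae_temporal_compression_ratio attribute and the latent-frame formula are assumptions, and _attention_type selection is model-specific and omitted here:

```python
# Hedged sketch, not a verified recipe: apply Enhance-A-Video to the LTX-Video transformer.
import torch

from diffusers import LTXPipeline
from diffusers.hooks import EnhanceAVideoConfig, apply_enhance_a_video

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")

num_frames = 121
# Assumed attribute and formula for converting pixel frames to latent frames.
latent_num_frames = (num_frames - 1) // pipe.vae_temporal_compression_ratio + 1

config = EnhanceAVideoConfig(weight=5.0, num_frames_callback=lambda: latent_num_frames)
apply_enhance_a_video(pipe.transformer, config)

video = pipe(
    prompt="A beautiful sunset over the ocean",
    width=768,
    height=512,
    num_frames=num_frames,
).frames[0]
```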
13 changes: 6 additions & 7 deletions src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -20,19 +20,18 @@

from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from ._common import (
_ATTENTION_CLASSES,
_CROSS_ATTENTION_BLOCK_IDENTIFIERS,
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS,
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS,
)
from .hooks import HookRegistry, ModelHook


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


_ATTENTION_CLASSES = (Attention, MochiAttention)

_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")


@dataclass
class PyramidAttentionBroadcastConfig:
r"""
4 changes: 3 additions & 1 deletion src/diffusers/models/modeling_utils.py
@@ -34,7 +34,6 @@
from typing_extensions import Self

from .. import __version__
from ..hooks import apply_group_offloading, apply_layerwise_casting
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
@@ -414,6 +413,7 @@ def enable_layerwise_casting(
non_blocking (`bool`, *optional*, defaults to `False`):
If `True`, the weight casting operations are non-blocking.
"""
from ..hooks import apply_layerwise_casting

user_provided_patterns = True
if skip_modules_pattern is None:
@@ -479,6 +479,8 @@ def enable_group_offload(
... )
```
"""
from ..hooks import apply_group_offloading

if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
msg = (
"Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "