Skip to content

Commit dbae971

Browse files
committed
update
1 parent c14057c commit dbae971

File tree

4 files changed

+206
-7
lines changed

4 files changed

+206
-7
lines changed

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33

44
if is_torch_available():
5+
from .enhance_a_video import EnhanceAVideoConfig, apply_enhance_a_video
56
from .group_offloading import apply_group_offloading
67
from .hooks import HookRegistry, ModelHook
78
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook

src/diffusers/hooks/_common.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from ..models.attention_processor import Attention, MochiAttention
2+
3+
4+
# Attention module classes shared by the hook implementations in this package.
_ATTENTION_CLASSES = (Attention, MochiAttention)

# Block-name identifiers grouped by attention kind.
# NOTE(review): presumably matched against module names by the hooks that import them — confirm in the callers.
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
src/diffusers/hooks/enhance_a_video.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dataclasses import dataclass
16+
from enum import Enum
17+
from typing import Callable
18+
19+
import torch
20+
import torch.overrides
21+
22+
from ..utils import get_logger
23+
from ._common import _ATTENTION_CLASSES
24+
from .hooks import HookRegistry, ModelHook
25+
26+
27+
logger = get_logger(__name__)
28+
29+
30+
_ENHANCE_A_VIDEO = "enhance_a_video"
31+
_ENHANCE_A_VIDEO_SDPA = "enhance_a_video_sdpa"
32+
33+
34+
class _AttentionType(Enum):
35+
SELF = 0
36+
JOINT___LATENTS_FIRST = 1
37+
JOINT___LATENTS_LAST = 2
38+
39+
40+
@dataclass
class EnhanceAVideoConfig:
    r"""
    Configuration for [Enhance A Video](https://huggingface.co/papers/2502.07508).

    Args:
        weight (`float`, defaults to `1.0`):
            Enhancement weight. It is added to the number of frames when the mean cross-frame attention is scaled
            into the final score.
        num_frames_callback (`Callable[[], int]`, defaults to `None`):
            Zero-argument callable returning the number of frames. It is invoked after every forward pass of a
            hooked attention layer, so the default `None` must be replaced before the hooks run.
        _attention_type (`_AttentionType`, defaults to `_AttentionType.SELF`):
            Selects which part of the captured query/key sequence is used to compute the score.
    """

    weight: float = 1.0
    num_frames_callback: Callable[[], int] = None
    _attention_type: _AttentionType = _AttentionType.SELF
49+
50+
51+
class EnhanceAVideoAttentionState:
    r"""
    Mutable per-layer state for Enhance-A-Video.

    Holds the most recently captured query and key tensors together with the latents sequence length recorded
    for the current forward pass. All fields are `None` until they are populated.
    """

    def __init__(self) -> None:
        self.query: torch.Tensor = None
        self.key: torch.Tensor = None
        self.latents_sequence_length: int = None

    def reset(self) -> None:
        r"""Clear the captured tensors and the recorded sequence length."""
        self.query = None
        self.key = None
        self.latents_sequence_length = None

    def __repr__(self) -> str:
        # Report shapes instead of tensor contents: the state has no `scores` attribute, and printing full
        # query/key tensors would be unreadable.
        query_shape = None if self.query is None else tuple(self.query.shape)
        key_shape = None if self.key is None else tuple(self.key.shape)
        return (
            f"EnhanceAVideoAttentionState(query_shape={query_shape}, key_shape={key_shape}, "
            f"latents_sequence_length={self.latents_sequence_length})"
        )
64+
65+
66+
class EnhanceAVideoCaptureSDPAInputsFunctionMode(torch.overrides.TorchFunctionMode):
    r"""
    Torch function mode that reports the query and key passed to
    `torch.nn.functional.scaled_dot_product_attention`.

    While the mode is active, every intercepted SDPA call forwards its query and key tensors to
    `query_key_save_callback` before the call itself is executed unchanged. All other torch functions pass
    through untouched.

    Args:
        query_key_save_callback (`Callable[[torch.Tensor, torch.Tensor], None]`):
            Called as `query_key_save_callback(query, key)` for each intercepted SDPA call.
    """

    def __init__(self, query_key_save_callback: Callable[[torch.Tensor, torch.Tensor], None]) -> None:
        super().__init__()

        self.query_key_save_callback = query_key_save_callback

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is torch.nn.functional.scaled_dot_product_attention:
            # Use explicit membership tests: `kwargs.get(...) or args[i]` would evaluate the truthiness of a
            # tensor, which raises for tensors holding more than one element.
            query = kwargs["query"] if "query" in kwargs else args[0]
            key = kwargs["key"] if "key" in kwargs else args[1]
            self.query_key_save_callback(query, key)
        return func(*args, **kwargs)
80+
81+
82+
class EnhanceAVideoSDPAHook(ModelHook):
    r"""
    Hook that rescales the output of an attention module by an Enhance-A-Video score.

    During the wrapped forward pass, the query and key tensors passed to scaled dot product attention are
    captured. After the forward pass they are regrouped per frame, scored with `enhance_a_video_score`, and the
    resulting scalar multiplies the module's `hidden_states` output.

    Args:
        weight (`float`):
            Enhancement weight forwarded to `enhance_a_video_score`.
        num_frames_callback (`Callable[[], int]`):
            Zero-argument callable returning the number of frames; invoked after every forward pass.
        _attention_type (`_AttentionType`):
            Selects which part of the captured query/key sequence holds the latent tokens.
    """

    _is_stateful = True

    def __init__(self, weight: float, num_frames_callback: Callable[[], int], _attention_type: _AttentionType) -> None:
        self.weight = weight
        self.num_frames_callback = num_frames_callback
        self._attention_type = _attention_type

    def initialize_hook(self, module):
        self.state = EnhanceAVideoAttentionState()
        return module

    def new_forward(self, module, *args, **kwargs):
        # Here, query and key have three possible shapes (considering the general diffusers-style model
        # implementation):
        # 1. [batch_size, attention_heads, latents_sequence_length, head_dim]
        # 2. [batch_size, attention_heads, latents_sequence_length + encoder_sequence_length, head_dim]
        # 3. [batch_size, attention_heads, encoder_sequence_length + latents_sequence_length, head_dim]
        kwargs_hidden_states = kwargs.get("hidden_states", None)
        hidden_states = kwargs_hidden_states if kwargs_hidden_states is not None else args[0]
        # NOTE(review): `size(2)` is the latents sequence length only if dim 2 of `hidden_states` is the token
        # axis; for a [batch, sequence, channels] input it would be the channel width. Confirm against the
        # attention modules this hook is attached to.
        self.state.latents_sequence_length = hidden_states.size(2)

        # Capture query and key tensors to compute EnhanceAVideo scores
        with EnhanceAVideoCaptureSDPAInputsFunctionMode(self._query_key_capture_callback):
            return self.fn_ref.original_forward(*args, **kwargs)

    def post_forward(self, module, output):
        # For diffusers models, or ones that are implemented similar to our design, we either return:
        #   1. A single output: `hidden_states`
        #   2. A tuple of outputs: `(hidden_states, encoder_hidden_states)`.
        # We need to handle both cases of applying EnhanceAVideo scores.
        hidden_states = output[0] if isinstance(output, tuple) else output

        query, key = self.state.query, self.state.key
        # Drop the references immediately so the captured activations are not kept alive between forward
        # passes and a stale capture can never be reused by a later call.
        self.state.query = None
        self.state.key = None
        if query is None or key is None:
            raise RuntimeError(
                "EnhanceAVideoSDPAHook did not capture any query/key tensors: the wrapped forward pass never "
                "called `torch.nn.functional.scaled_dot_product_attention`."
            )

        # Keep only the latent tokens of the captured query and key tensors
        if self._attention_type == _AttentionType.JOINT___LATENTS_FIRST:
            query = query[:, :, : self.state.latents_sequence_length]
            key = key[:, :, : self.state.latents_sequence_length]
        elif self._attention_type == _AttentionType.JOINT___LATENTS_LAST:
            query = query[:, :, -self.state.latents_sequence_length :]
            key = key[:, :, -self.state.latents_sequence_length :]

        num_frames = self.num_frames_callback()

        def reshape_for_framewise_attention(tensor: torch.Tensor) -> torch.Tensor:
            # This code assumes tensor is [B, N, S, C]. This should be true for most diffusers-style implementations.
            # [B, N, F * S, C] -> [B, N, F, S, C] -> [B, S, N, F, C] -> [B * S, N, F, C]
            return tensor.unflatten(2, (num_frames, -1)).permute(0, 3, 1, 2, 4).flatten(0, 1)

        query = reshape_for_framewise_attention(query)
        key = reshape_for_framewise_attention(key)
        scores = enhance_a_video_score(query, key, num_frames, self.weight)
        hidden_states = hidden_states * scores

        return (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states

    def reset_state(self, module):
        self.state.reset()
        return module

    def _query_key_capture_callback(self, query: torch.Tensor, key: torch.Tensor) -> None:
        # Invoked by the capture mode for each intercepted SDPA call; the last call wins.
        self.state.query = query
        self.state.key = key
146+
147+
148+
def enhance_a_video_score(
    query: torch.Tensor, key: torch.Tensor, num_frames: int, weight: float = 1.0
) -> torch.Tensor:
    r"""
    Compute the Enhance-A-Video scaling factor from frame-wise query and key tensors.

    The softmax attention between frames is computed, the diagonal (each frame attending to itself) is removed,
    and the mean of the remaining cross-frame attention is scaled by `num_frames + weight`. The result is
    clamped to be at least 1.

    Args:
        query (`torch.Tensor`):
            Query tensor whose last two dimensions are `(num_frames, head_dim)`.
        key (`torch.Tensor`):
            Key tensor whose last two dimensions are `(num_frames, head_dim)`.
        num_frames (`int`):
            Number of frames, i.e. the size of the attention matrices.
        weight (`float`, defaults to `1.0`):
            Enhancement weight added to `num_frames` when scaling the mean cross-frame attention.

    Returns:
        `torch.Tensor`: A scalar (0-dim) float32 tensor with value `>= 1`.
    """
    if num_frames <= 1:
        # A single frame has no off-diagonal attention entries; the mean below would be 0 / 0 = NaN, which
        # `clamp` does not repair. Return the neutral factor instead.
        return torch.ones((), dtype=torch.float32, device=query.device)

    scale = 1 / (query.size(-1) ** 0.5)
    attn_temp = (query * scale) @ key.transpose(-2, -1)
    attn_temp = attn_temp.float().softmax(dim=-1)

    # Reshape to [batch_size * num_tokens, num_frames, num_frames]
    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

    # Zero out the diagonal elements; the [num_frames, num_frames] mask broadcasts over the leading dimension
    diag_mask = torch.eye(num_frames, dtype=torch.bool, device=attn_temp.device)
    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

    # Mean over the off-diagonal elements of each matrix; there are n * n - n of them
    num_off_diag = num_frames * num_frames - num_frames
    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

    scores = mean_scores.mean() * (num_frames + weight)
    return scores.clamp(min=1)
177+
178+
179+
def apply_enhance_a_video(module: torch.nn.Module, config: EnhanceAVideoConfig) -> None:
    r"""
    Register an Enhance-A-Video hook on every eligible attention layer of `module`.

    A submodule is eligible when it is an instance of one of `_ATTENTION_CLASSES` and its `is_cross_attention`
    attribute is absent or falsy.

    Args:
        module (`torch.nn.Module`):
            The model whose submodules are scanned.
        config (`EnhanceAVideoConfig`):
            Settings copied into each registered hook.

    Raises:
        `ValueError`: If `config.num_frames_callback` is not callable.
    """
    # The callback is copied into every hook at registration time, so a missing one can only fail later, in the
    # middle of a forward pass, with an opaque "'NoneType' object is not callable". Fail early instead.
    if not callable(config.num_frames_callback):
        raise ValueError(
            "`config.num_frames_callback` must be a callable returning the number of frames, but got "
            f"{config.num_frames_callback!r}."
        )

    for name, submodule in module.named_modules():
        is_cross_attention = getattr(submodule, "is_cross_attention", False)
        if not isinstance(submodule, _ATTENTION_CLASSES) or is_cross_attention:
            continue
        logger.debug(f"Applying Enhance-A-Video to layer '{name}'")
        hook_registry = HookRegistry.check_if_exists_or_initialize(submodule)
        hook = EnhanceAVideoSDPAHook(
            weight=config.weight,
            num_frames_callback=config.num_frames_callback,
            _attention_type=config._attention_type,
        )
        hook_registry.register_hook(hook, _ENHANCE_A_VIDEO_SDPA)

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,18 @@
2020

2121
from ..models.attention_processor import Attention, MochiAttention
2222
from ..utils import logging
23+
from ._common import (
24+
_ATTENTION_CLASSES,
25+
_CROSS_ATTENTION_BLOCK_IDENTIFIERS,
26+
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS,
27+
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS,
28+
)
2329
from .hooks import HookRegistry, ModelHook
2430

2531

2632
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
2733

2834

29-
_ATTENTION_CLASSES = (Attention, MochiAttention)
30-
31-
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
32-
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
33-
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
34-
35-
3635
@dataclass
3736
class PyramidAttentionBroadcastConfig:
3837
r"""

0 commit comments

Comments
 (0)