|
| 1 | +# Copyright 2024 The HuggingFace Team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from dataclasses import dataclass |
| 16 | +from enum import Enum |
| 17 | +from typing import Callable |
| 18 | + |
| 19 | +import torch |
| 20 | +import torch.overrides |
| 21 | + |
| 22 | +from ..utils import get_logger |
| 23 | +from ._common import _ATTENTION_CLASSES |
| 24 | +from .hooks import HookRegistry, ModelHook |
| 25 | + |
| 26 | + |
| 27 | +logger = get_logger(__name__) |
| 28 | + |
| 29 | + |
| 30 | +_ENHANCE_A_VIDEO = "enhance_a_video" |
| 31 | +_ENHANCE_A_VIDEO_SDPA = "enhance_a_video_sdpa" |
| 32 | + |
| 33 | + |
| 34 | +class _AttentionType(Enum): |
| 35 | + SELF = 0 |
| 36 | + JOINT___LATENTS_FIRST = 1 |
| 37 | + JOINT___LATENTS_LAST = 2 |
| 38 | + |
| 39 | + |
| 40 | +@dataclass |
| 41 | +class EnhanceAVideoConfig: |
| 42 | + r""" |
| 43 | + Configuration for [Enhance A Video](https://huggingface.co/papers/2502.07508). |
| 44 | + """ |
| 45 | + |
| 46 | + weight: float = 1.0 |
| 47 | + num_frames_callback: Callable[[], int] = None |
| 48 | + _attention_type: _AttentionType = _AttentionType.SELF |
| 49 | + |
| 50 | + |
| 51 | +class EnhanceAVideoAttentionState: |
| 52 | + def __init__(self) -> None: |
| 53 | + self.query: torch.Tensor = None |
| 54 | + self.key: torch.Tensor = None |
| 55 | + self.latents_sequence_length: int = None |
| 56 | + |
| 57 | + def reset(self) -> None: |
| 58 | + self.query = None |
| 59 | + self.key = None |
| 60 | + self.latents_sequence_length = None |
| 61 | + |
| 62 | + def __repr__(self): |
| 63 | + return f"EnhanceAVideoAttentionState(scores={self.scores}, latents_sequence_length={self.latents_sequence_length})" |
| 64 | + |
| 65 | + |
| 66 | +class EnhanceAVideoCaptureSDPAInputsFunctionMode(torch.overrides.TorchFunctionMode): |
| 67 | + def __init__(self, query_key_save_callback: Callable[[torch.Tensor, torch.Tensor], None]) -> None: |
| 68 | + super().__init__() |
| 69 | + |
| 70 | + self.query_key_save_callback = query_key_save_callback |
| 71 | + |
| 72 | + def __torch_function__(self, func, types, args=(), kwargs=None): |
| 73 | + if kwargs is None: |
| 74 | + kwargs = {} |
| 75 | + if func is torch.nn.functional.scaled_dot_product_attention: |
| 76 | + query = kwargs.get("query", None) or args[0] |
| 77 | + key = kwargs.get("key", None) or args[1] |
| 78 | + self.query_key_save_callback(query, key) |
| 79 | + return func(*args, **(kwargs or {})) |
| 80 | + |
| 81 | + |
| 82 | +class EnhanceAVideoSDPAHook(ModelHook): |
| 83 | + _is_stateful = True |
| 84 | + |
| 85 | + def __init__(self, weight: float, num_frames_callback: Callable[[], int], _attention_type: _AttentionType) -> None: |
| 86 | + self.weight = weight |
| 87 | + self.num_frames_callback = num_frames_callback |
| 88 | + self._attention_type = _attention_type |
| 89 | + |
| 90 | + def initialize_hook(self, module): |
| 91 | + self.state = EnhanceAVideoAttentionState() |
| 92 | + return module |
| 93 | + |
| 94 | + def new_forward(self, module, *args, **kwargs): |
| 95 | + # Here, query and key have two shapes (considering the general diffusers-style model implementation): |
| 96 | + # 1. [batch_size, attention_heads, latents_sequence_length, head_dim] |
| 97 | + # 2. [batch_size, attention_heads, latents_sequence_length + encoder_sequence_length, head_dim] |
| 98 | + # 3. [batch_size, attention_heads, encoder_sequence_length + latents_sequence_length, head_dim] |
| 99 | + kwargs_hidden_states = kwargs.get("hidden_states", None) |
| 100 | + hidden_states = kwargs_hidden_states if kwargs_hidden_states is not None else args[0] |
| 101 | + self.state.latents_sequence_length = hidden_states.size(2) |
| 102 | + |
| 103 | + # Capture query and key tensors to compute EnhanceAVideo scores |
| 104 | + with EnhanceAVideoCaptureSDPAInputsFunctionMode(self._query_key_capture_callback): |
| 105 | + return self.fn_ref.original_forward(*args, **kwargs) |
| 106 | + |
| 107 | + def post_forward(self, module, output): |
| 108 | + # For diffusers models, or ones that are implemented similar to our design, we either return: |
| 109 | + # 1. A single output: `hidden_states` |
| 110 | + # 2. A tuple of outputs: `(hidden_states, encoder_hidden_states)`. |
| 111 | + # We need to handle both cases of applying EnhanceAVideo scores. |
| 112 | + hidden_states = output[0] if isinstance(output, tuple) else output |
| 113 | + |
| 114 | + def reshape_for_framewise_attention(tensor: torch.Tensor) -> torch.Tensor: |
| 115 | + # This code assumes tensor is [B, N, S, C]. This should be true for most diffusers-style implementations. |
| 116 | + # [B, N, S, C] -> [B, N, F, S, C] -> [B, S, N, F, C] -> [B * S, N, F, C] |
| 117 | + return tensor.unflatten(2, (num_frames, -1)).permute(0, 3, 1, 2, 4).flatten(0, 1) |
| 118 | + |
| 119 | + # Handle reshaping of query and key tensors |
| 120 | + query, key = self.state.query, self.state.key |
| 121 | + if self._attention_type == _AttentionType.SELF: |
| 122 | + pass |
| 123 | + elif self._attention_type == _AttentionType.JOINT___LATENTS_FIRST: |
| 124 | + query = query[:, :, : self.state.latents_sequence_length] |
| 125 | + key = key[:, :, : self.state.latents_sequence_length] |
| 126 | + elif self._attention_type == _AttentionType.JOINT___LATENTS_LAST: |
| 127 | + query = query[:, :, -self.state.latents_sequence_length :] |
| 128 | + key = key[:, :, -self.state.latents_sequence_length :] |
| 129 | + |
| 130 | + num_frames = self.num_frames_callback() |
| 131 | + query = reshape_for_framewise_attention(query) |
| 132 | + key = reshape_for_framewise_attention(key) |
| 133 | + scores = enhance_a_video_score(query, key, num_frames, self.weight) |
| 134 | + print("Applying scores:", scores) |
| 135 | + hidden_states = hidden_states * scores |
| 136 | + |
| 137 | + return (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states |
| 138 | + |
| 139 | + def reset_state(self, module): |
| 140 | + self.state.reset() |
| 141 | + return module |
| 142 | + |
| 143 | + def _query_key_capture_callback(self, query: torch.Tensor, key: torch.Tensor) -> None: |
| 144 | + self.state.query = query |
| 145 | + self.state.key = key |
| 146 | + |
| 147 | + |
| 148 | +def enhance_a_video_score( |
| 149 | + query: torch.Tensor, key: torch.Tensor, num_frames: int, weight: float = 1.0 |
| 150 | +) -> torch.Tensor: |
| 151 | + head_dim = query.size(-1) |
| 152 | + scale = 1 / (head_dim**0.5) |
| 153 | + query = query * scale |
| 154 | + |
| 155 | + attn_temp = query @ key.transpose(-2, -1) |
| 156 | + attn_temp = attn_temp.float() |
| 157 | + attn_temp = attn_temp.softmax(dim=-1) |
| 158 | + |
| 159 | + # Reshape to [batch_size * num_tokens, num_frames, num_frames] |
| 160 | + attn_temp = attn_temp.reshape(-1, num_frames, num_frames) |
| 161 | + |
| 162 | + # Create a mask for diagonal elements |
| 163 | + diag_mask = torch.eye(num_frames, device=attn_temp.device).bool() |
| 164 | + diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.size(0), -1, -1) |
| 165 | + |
| 166 | + # Zero out diagonal elements |
| 167 | + attn_wo_diag = attn_temp.masked_fill(diag_mask, 0) |
| 168 | + |
| 169 | + # Calculate mean for each token's attention matrix |
| 170 | + # Number of off-diagonal elements per matrix is n*n - n |
| 171 | + num_off_diag = num_frames * num_frames - num_frames |
| 172 | + mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag |
| 173 | + |
| 174 | + scores = mean_scores.mean() * (num_frames + weight) |
| 175 | + scores = scores.clamp(min=1) |
| 176 | + return scores |
| 177 | + |
| 178 | + |
| 179 | +def apply_enhance_a_video(module: torch.nn.Module, config: EnhanceAVideoConfig) -> None: |
| 180 | + for name, submodule in module.named_modules(): |
| 181 | + is_cross_attention = getattr(submodule, "is_cross_attention", False) |
| 182 | + if not isinstance(submodule, _ATTENTION_CLASSES) or is_cross_attention: |
| 183 | + continue |
| 184 | + logger.debug(f"Applying Enhance-A-Video to layer '{name}'") |
| 185 | + hook_registry = HookRegistry.check_if_exists_or_initialize(submodule) |
| 186 | + hook = EnhanceAVideoSDPAHook( |
| 187 | + weight=config.weight, |
| 188 | + num_frames_callback=config.num_frames_callback, |
| 189 | + _attention_type=config._attention_type, |
| 190 | + ) |
| 191 | + hook_registry.register_hook(hook, _ENHANCE_A_VIDEO_SDPA) |
0 commit comments