diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py
index 56be0bbdf305..e54485232d5b 100644
--- a/src/diffusers/hooks/__init__.py
+++ b/src/diffusers/hooks/__init__.py
@@ -2,6 +2,7 @@

 if is_torch_available():
+    from .enhance_a_video import EnhanceAVideoConfig, apply_enhance_a_video, remove_enhance_a_video
     from .group_offloading import apply_group_offloading
     from .hooks import HookRegistry, ModelHook
     from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py
new file mode 100644
index 000000000000..9b92bfc07850
--- /dev/null
+++ b/src/diffusers/hooks/_common.py
@@ -0,0 +1,16 @@
+from ..models.attention_processor import Attention, MochiAttention
+
+
+_ATTENTION_CLASSES = (Attention, MochiAttention)
+
+_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
+_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
+_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
+
+_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
+    {
+        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    }
+)
diff --git a/src/diffusers/hooks/enhance_a_video.py b/src/diffusers/hooks/enhance_a_video.py
new file mode 100644
index 000000000000..a9286417c1a5
--- /dev/null
+++ b/src/diffusers/hooks/enhance_a_video.py
@@ -0,0 +1,302 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from dataclasses import dataclass
+from enum import Enum
+from typing import Callable, Dict, Union
+
+import torch
+import torch.overrides
+
+from ..utils import get_logger
+from ._common import _ATTENTION_CLASSES
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger(__name__)
+
+
+_ENHANCE_A_VIDEO = "enhance_a_video"
+
+
+class _AttentionType(int, Enum):
+    SELF = 0
+    JOINT___LATENTS_FIRST = 1
+    JOINT___LATENTS_LAST = 2
+
+
+@dataclass
+class EnhanceAVideoConfig:
+    r"""
+    Configuration for [Enhance A Video](https://huggingface.co/papers/2502.07508).
+
+    Args:
+        weight (`float` or `Dict[str, float]`, defaults to `1.0`):
+            The weighting factor for the Enhance A Video score. If a `float`, the same weight is applied to all
+            layers. If a `dict`, the keys are regex patterns that match non-overlapping layer names, and the values
+            are the corresponding weights.
+        num_frames_callback (`Callable[[], int]`, *optional*):
+            A callback function that returns the number of latent frames in the latent video stream. Since there is
+            no easy way to deduce this within the attention layers, the user must provide this information.
+        _attention_type (`_AttentionType`, defaults to `_AttentionType.SELF`):
+            The type of attention mechanism that the underlying model uses. The following options are available:
+            - `_AttentionType.SELF`:
+                The model uses self-attention layers with only video tokens.
+            - `_AttentionType.JOINT___LATENTS_FIRST`:
+                The model uses joint attention layers (concatenated video and text stream data) with video tokens
+                first.
+            - `_AttentionType.JOINT___LATENTS_LAST`:
+                The model uses joint attention layers (concatenated video and text stream data) with video tokens
+                last.
+
+            This parameter is not guaranteed to be backwards-compatible and may change in future versions.
+    """
+
+    weight: Union[float, Dict[str, float]] = 1.0
+    num_frames_callback: Callable[[], int] = None
+    _attention_type: _AttentionType = _AttentionType.SELF
+
+
+class EnhanceAVideoAttentionState:
+    def __init__(self) -> None:
+        self.query: torch.Tensor = None
+        self.key: torch.Tensor = None
+        self.latents_sequence_length: int = None
+
+    def reset(self) -> None:
+        self.query = None
+        self.key = None
+        self.latents_sequence_length = None
+
+    def __repr__(self):
+        return f"EnhanceAVideoAttentionState(latents_sequence_length={self.latents_sequence_length})"
+
+
+class EnhanceAVideoCaptureSDPAInputsFunctionMode(torch.overrides.TorchFunctionMode):
+    def __init__(self, query_key_save_callback: Callable[[torch.Tensor, torch.Tensor], None]) -> None:
+        super().__init__()
+
+        self.query_key_save_callback = query_key_save_callback
+
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        # TODO(aryan): revisit for torch.compile -- can trace into ATen but not triton, so this branch is never hit
+        if func is torch.nn.functional.scaled_dot_product_attention:
+            # Avoid `or`-based fallbacks here: truthiness checks on multi-element tensors raise an error.
+            query = kwargs["query"] if "query" in kwargs else args[0]
+            key = kwargs["key"] if "key" in kwargs else args[1]
+            self.query_key_save_callback(query, key)
+        return func(*args, **kwargs)
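+
+
+# Usage sketch for the capture mode above (illustrative only; the hook below wires this up for real):
+#
+#     captured = {}
+#     with EnhanceAVideoCaptureSDPAInputsFunctionMode(lambda q, k: captured.update(query=q, key=k)):
+#         out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
+#     # captured["query"] / captured["key"] now hold the exact tensors that entered SDPA.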
+
+
+class EnhanceAVideoSDPAHook(ModelHook):
+    _is_stateful = True
+
+    def __init__(self, weight: float, num_frames_callback: Callable[[], int], _attention_type: _AttentionType) -> None:
+        self.weight = weight
+        self.num_frames_callback = num_frames_callback
+        self._attention_type = _attention_type
+
+    def initialize_hook(self, module):
+        self.state = EnhanceAVideoAttentionState()
+        return module
+
+    def new_forward(self, module, *args, **kwargs):
+        # Here, query and key have one of three shapes (considering the general diffusers-style model
+        # implementation):
+        #   1. [batch_size, attention_heads, latents_sequence_length, head_dim]
+        #   2. [batch_size, attention_heads, latents_sequence_length + encoder_sequence_length, head_dim]
+        #   3. [batch_size, attention_heads, encoder_sequence_length + latents_sequence_length, head_dim]
+        kwargs_hidden_states = kwargs.get("hidden_states", None)
+        hidden_states = kwargs_hidden_states if kwargs_hidden_states is not None else args[0]
+        self.state.latents_sequence_length = hidden_states.size(1)
+
+        # Capture query and key tensors to compute EnhanceAVideo scores
+        with EnhanceAVideoCaptureSDPAInputsFunctionMode(self._query_key_capture_callback):
+            return self.fn_ref.original_forward(*args, **kwargs)
+
+    def post_forward(self, module, output):
+        # For diffusers models, or ones implemented similarly to our design, the attention layer returns either:
+        #   1. A single output: `hidden_states`
+        #   2. A tuple of outputs: `(hidden_states, encoder_hidden_states)`
+        # We need to handle both cases when applying the EnhanceAVideo scores.
+        hidden_states = output[0] if isinstance(output, tuple) else output
+
+        def reshape_for_framewise_attention(tensor: torch.Tensor) -> torch.Tensor:
+            # This code assumes tensor is [B, N, S, C]. This should be true for most diffusers-style implementations.
+            # [B, N, S, C] -> [B, N, F, S, C] -> [B, S, N, F, C] -> [B * S, N, F, C]
+            return tensor.unflatten(2, (num_frames, -1)).permute(0, 3, 1, 2, 4).flatten(0, 1)
+
+        # Handle reshaping of query and key tensors
+        query, key = self.state.query, self.state.key
+        if self._attention_type == _AttentionType.SELF:
+            pass
+        elif self._attention_type == _AttentionType.JOINT___LATENTS_FIRST:
+            query = query[:, :, : self.state.latents_sequence_length]
+            key = key[:, :, : self.state.latents_sequence_length]
+        elif self._attention_type == _AttentionType.JOINT___LATENTS_LAST:
+            query = query[:, :, -self.state.latents_sequence_length :]
+            key = key[:, :, -self.state.latents_sequence_length :]
+
+        num_frames = self.num_frames_callback()
+        query = reshape_for_framewise_attention(query)
+        key = reshape_for_framewise_attention(key)
+        scores = enhance_a_video_score(query, key, num_frames, self.weight)
+        hidden_states = hidden_states * scores
+
+        return (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states
+
+    def reset_state(self, module):
+        self.state.reset()
+        return module
+
+    def _query_key_capture_callback(self, query: torch.Tensor, key: torch.Tensor) -> None:
+        self.state.query = query
+        self.state.key = key
+
+
+def enhance_a_video_score(
+    query: torch.Tensor, key: torch.Tensor, num_frames: int, weight: float = 1.0
+) -> torch.Tensor:
+    head_dim = query.size(-1)
+    scale = 1 / (head_dim**0.5)
+    query = query * scale
+
+    attn_temp = query @ key.transpose(-2, -1)
+    attn_temp = attn_temp.float()
+    attn_temp = attn_temp.softmax(dim=-1)
+
+    # Reshape to [batch_size * num_tokens, num_frames, num_frames]
+    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)
+
+    # Create a mask for diagonal elements
+    diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
+    diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.size(0), -1, -1)
+
+    # Zero out diagonal elements
+    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)
+
+    # Calculate mean for each token's attention matrix
+    # Number of off-diagonal elements per matrix is n * n - n
+    num_off_diag = num_frames * num_frames - num_frames
+    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag
+
+    scores = mean_scores.mean() * (num_frames + weight)
+    scores = scores.clamp(min=1)
+    return scores
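+
+
+# In effect, the returned value is a single scalar multiplier:
+#     score = clamp(mean_offdiag(softmax(QK^T / sqrt(d))) * (num_frames + weight), min=1)
+# where the softmax yields an [F, F] temporal attention map per head and token position, and
+# mean_offdiag averages its non-diagonal entries, i.e. the average cross-frame attention intensity.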
+
+
+def apply_enhance_a_video(module: torch.nn.Module, config: EnhanceAVideoConfig) -> None:
+    r"""
+    Applies [Enhance A Video](https://huggingface.co/papers/2502.07508) on a model.
+
+    This function registers a diffusers hook on all (or user-configured) self-attention layers of the model. The
+    hook captures the query and key tensors entering `torch.nn.functional.scaled_dot_product_attention` and
+    computes the Enhance-A-Video score from them.
+
+    Args:
+        module (`torch.nn.Module`):
+            The model to apply Enhance A Video to. It must be a video generation model.
+        config (`EnhanceAVideoConfig`):
+            The configuration for Enhance A Video.
+
+    Example:
+
+    ```python
+    >>> import torch
+    >>> from diffusers import LTXPipeline
+    >>> from diffusers.hooks import apply_enhance_a_video, remove_enhance_a_video, EnhanceAVideoConfig
+    >>> from diffusers.utils import export_to_video
+
+    >>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
+    >>> pipe.to("cuda")
+
+    >>> weight = 4
+    >>> num_frames = 161
+    >>> latent_num_frames = (num_frames - 1) // pipe.vae_temporal_compression_ratio + 1

+    >>> # Apply Enhance-A-Video to all layers with a weight of 4
+    >>> config = EnhanceAVideoConfig(weight=weight, num_frames_callback=lambda: latent_num_frames, _attention_type=1)
+    >>> apply_enhance_a_video(pipe.transformer, config)

+    >>> prompt = "A man standing in a sunlit garden, surrounded by lush greenery and colorful flowers. The man has a knife in his hand and is cutting a ripe, juicy watermelon. The watermelon is bright red and contrasts beautifully with the green foliage in the background."
+    >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

+    >>> video = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=161).frames[0]
+    >>> export_to_video(video, "output.mp4", fps=24)

+    >>> # Remove Enhance-A-Video
+    >>> remove_enhance_a_video(pipe.transformer)

+    >>> # Apply Enhance-A-Video to specific layers with different weights
+    >>> config = EnhanceAVideoConfig(
+    ...     weight={
+    ...         "blocks\.(0|1|2|3|4|5|6|7)\.": 5.0,
+    ...         "blocks\.(10|11|12|13|14|15)\.": 8.0,
+    ...         "blocks\.(21|22|23|24|25|26)\.": 3.0,
+    ...     },
+    ...     num_frames_callback=lambda: latent_num_frames,
+    ...     _attention_type=1,
+    ... )
+    >>> apply_enhance_a_video(pipe.transformer, config)
+    ```
+    """
+    weight = config.weight
+    if not isinstance(weight, dict):
+        weight = {".*": config.weight}
+    _validate_weight(module, weight)
+
+    weight_keys = set(weight.keys())
+    for name, submodule in module.named_modules():
+        # We cannot apply Enhance-A-Video to cross-attention layers
+        is_cross_attention = getattr(submodule, "is_cross_attention", False)
+        if not isinstance(submodule, _ATTENTION_CLASSES) or is_cross_attention:
+            continue
+        current_weight = next(
+            (weight[identifier] for identifier in weight_keys if re.search(identifier, name) is not None), None
+        )
+        if current_weight is None:
+            continue
+        logger.debug(f"Applying Enhance-A-Video to layer '{name}'")
+        hook_registry = HookRegistry.check_if_exists_or_initialize(submodule)
+        hook = EnhanceAVideoSDPAHook(
+            weight=current_weight,
+            num_frames_callback=config.num_frames_callback,
+            _attention_type=config._attention_type,
+        )
+        hook_registry.register_hook(hook, _ENHANCE_A_VIDEO)
+
+
+def remove_enhance_a_video(module: torch.nn.Module) -> None:
+    r"""
+    Removes the Enhance A Video hook from the model.
+
+    See [`~hooks.enhance_a_video.apply_enhance_a_video`] for an example.
+ """ + for name, submodule in module.named_modules(): + if not hasattr(submodule, "_diffusers_hook"): + continue + hook_registry = submodule._diffusers_hook + hook_registry.remove_hook(_ENHANCE_A_VIDEO, recurse=False) + logger.debug(f"Removed Enhance-A-Video from layer '{name}'") + + +def _validate_weight(module: torch.nn.Module, weight: Dict[str, float]) -> None: + if not isinstance(weight, dict): + raise ValueError(f"Invalid weight type: {type(weight)}") + weight_keys = set(weight.keys()) + for name, _ in module.named_modules(): + num_matches = sum(re.search(identifier, name) is not None for identifier in weight_keys) + if num_matches > 1: + raise ValueError( + f"The provided weight dictionary has multiple regex matches for layer '{name}'. Please provide non-overlapping regex patterns." + ) diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index 9f8597d52f8c..e6c06aaa4456 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -20,19 +20,18 @@ from ..models.attention_processor import Attention, MochiAttention from ..utils import logging +from ._common import ( + _ATTENTION_CLASSES, + _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, + _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS, + _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS, +) from .hooks import HookRegistry, ModelHook logger = logging.get_logger(__name__) # pylint: disable=invalid-name -_ATTENTION_CLASSES = (Attention, MochiAttention) - -_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks") -_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) -_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks") - - @dataclass class PyramidAttentionBroadcastConfig: r""" @@ -76,9 +75,9 @@ class PyramidAttentionBroadcastConfig: temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800) cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800) - spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS - temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS - cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS + spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS + temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS + cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS current_timestep_callback: Callable[[], int] = None diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 61d8d076aab0..368334eb21b0 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -34,7 +34,6 @@ from typing_extensions import Self from .. import __version__ -from ..hooks import apply_group_offloading, apply_layerwise_casting from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( @@ -414,6 +413,7 @@ def enable_layerwise_casting( non_blocking (`bool`, *optional*, defaults to `False`): If `True`, the weight casting operations are non-blocking. """ + from ..hooks import apply_layerwise_casting user_provided_patterns = True if skip_modules_pattern is None: @@ -479,6 +479,8 @@ def enable_group_offload( ... 
+        from ..hooks import apply_layerwise_casting

         user_provided_patterns = True
         if skip_modules_pattern is None:
@@ -479,6 +479,8 @@ def enable_group_offload(
         ... )
         ```
         """
+        from ..hooks import apply_group_offloading
+
         if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
             msg = (
                 "Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
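
As a quick sanity check of the scoring function above, here is a minimal sketch (not part of the diff) that exercises `enhance_a_video_score` directly. It assumes the module path `diffusers.hooks.enhance_a_video` introduced by this diff and uses random inputs, so the printed value only illustrates the expected shapes and clamping behavior:

```python
import torch

from diffusers.hooks.enhance_a_video import enhance_a_video_score

# Shapes follow `reshape_for_framewise_attention`: [batch * tokens_per_frame, heads, frames, head_dim]
tokens_per_frame, heads, frames, head_dim = 16, 4, 8, 64
query = torch.randn(tokens_per_frame, heads, frames, head_dim)
key = torch.randn_like(query)

score = enhance_a_video_score(query, key, num_frames=frames, weight=4.0)
assert score.ndim == 0 and score.item() >= 1.0  # a scalar multiplier, clamped to at least 1
print(score)
```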