
Commit 58933b5

update

1 parent 2ff1716 commit 58933b5

3 files changed (+127, -16 lines)


src/diffusers/hooks/_common.py

Lines changed: 11 additions & 3 deletions
@@ -3,6 +3,14 @@
 
 _ATTENTION_CLASSES = (Attention, MochiAttention)
 
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
+_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
+_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
+_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
+
+_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
+    {
+        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    }
+)
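
Note on the new `_ALL_TRANSFORMER_BLOCK_IDENTIFIERS`: spreading the three tuples into a set deduplicates shared names such as "blocks" and "layers", at the cost of an unspecified ordering. A minimal sketch of the semantics (the `is_transformer_block_path` helper is hypothetical, added only to illustrate a typical membership check):

```python
# Sketch: deduplicated union of the three identifier tuples (order not guaranteed).
_SPATIAL = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL = ("temporal_transformer_blocks",)
_CROSS = ("blocks", "transformer_blocks", "layers")

_ALL = tuple({*_SPATIAL, *_TEMPORAL, *_CROSS})

# "blocks" and "layers" each appear once despite being in two source tuples.
assert sorted(_ALL) == sorted(
    {"blocks", "transformer_blocks", "single_transformer_blocks", "layers", "temporal_transformer_blocks"}
)

# Hypothetical use: decide whether a module path looks like a transformer block list.
def is_transformer_block_path(name: str) -> bool:
    return any(identifier in name.split(".") for identifier in _ALL)

print(is_transformer_block_path("transformer_blocks.0.attn1"))  # True
```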

src/diffusers/hooks/enhance_a_video.py

Lines changed: 110 additions & 7 deletions
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
 from dataclasses import dataclass
 from enum import Enum
-from typing import Callable
+from typing import Callable, Dict, Union
 
 import torch
 import torch.overrides
@@ -30,7 +31,7 @@
 _ENHANCE_A_VIDEO = "enhance_a_video"
 
 
-class _AttentionType(Enum):
+class _AttentionType(int, Enum):
     SELF = 0
     JOINT___LATENTS_FIRST = 1
     JOINT___LATENTS_LAST = 2
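
Mixing `int` into the enum bases makes `_AttentionType` an int-backed enum, which is presumably why the docstring example below can pass `_attention_type=1`: int-enum members compare equal to plain integers. A minimal sketch of the behavior:

```python
from enum import Enum

class _AttentionType(int, Enum):
    SELF = 0
    JOINT___LATENTS_FIRST = 1
    JOINT___LATENTS_LAST = 2

# int-backed members compare equal to plain ints, so configs built with
# `_attention_type=1` behave the same as `_AttentionType.JOINT___LATENTS_FIRST`.
assert _AttentionType.JOINT___LATENTS_FIRST == 1
assert _AttentionType(2) is _AttentionType.JOINT___LATENTS_LAST
```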
@@ -40,9 +41,30 @@ class _AttentionType(Enum):
 class EnhanceAVideoConfig:
     r"""
     Configuration for [Enhance A Video](https://huggingface.co/papers/2502.07508).
+
+    Args:
+        weight (`float` or `Dict[str, float]`, defaults to `1.0`):
+            The weighting factor for the Enhance A Video score. If a `float`, the same weight is applied to all layers.
+            If a `dict`, the keys are regex patterns that match non-overlapping layer names, and the values are the
+            corresponding weights.
+        num_frames_callback (`Callable[[], int]`, `optional`):
+            A callback function that returns the number of latent frames in the latent video stream. Since there is no
+            easy way to deduce this within the attention layers, the user must provide this information.
+        _attention_type (`_AttentionType`, defaults to `_AttentionType.SELF`):
+            The type of attention mechanism that the underlying model uses. The following options are available:
+            - `_AttentionType.SELF`:
+                The model uses self-attention layers with only video tokens.
+            - `_AttentionType.JOINT___LATENTS_FIRST`:
+                The model uses joint attention layers (concatenated video and text stream data) with video tokens
+                first.
+            - `_AttentionType.JOINT___LATENTS_LAST`:
+                The model uses joint attention layers (concatenated video and text stream data) with video tokens
+                last.
+
+            This parameter is not backwards-compatible and may be subject to change in future versions.
     """
 
-    weight: float = 1.0
+    weight: Union[float, Dict[str, float]] = 1.0
     num_frames_callback: Callable[[], int] = None
     _attention_type: _AttentionType = _AttentionType.SELF
 
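
With the dict form of `weight`, each attention layer's qualified name is matched against the regex keys via `re.search`, as the new `apply_enhance_a_video` body further down shows. A minimal standalone sketch of that resolution, with made-up layer names and patterns:

```python
import re
from typing import Dict, Optional

weight: Dict[str, float] = {
    r"blocks\.(0|1|2)\.": 5.0,   # early blocks: stronger enhancement
    r"blocks\.(10|11)\.": 2.0,   # later blocks: weaker enhancement
}

def resolve_weight(layer_name: str) -> Optional[float]:
    # Mirrors the lookup in apply_enhance_a_video: the first pattern that
    # re.search finds anywhere in the layer name supplies the weight;
    # None means the layer is skipped entirely.
    return next((w for pattern, w in weight.items() if re.search(pattern, layer_name)), None)

print(resolve_weight("blocks.1.attn1"))   # 5.0
print(resolve_weight("blocks.10.attn1"))  # 2.0
print(resolve_weight("blocks.20.attn1"))  # None -> no hook registered
```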
@@ -59,7 +81,7 @@ def reset(self) -> None:
         self.latents_sequence_length = None
 
     def __repr__(self):
-        return f"EnhanceAVideoAttentionState(scores={self.scores}, latents_sequence_length={self.latents_sequence_length})"
+        return f"EnhanceAVideoAttentionState(latents_sequence_length={self.latents_sequence_length})"
 
 
 class EnhanceAVideoCaptureSDPAInputsFunctionMode(torch.overrides.TorchFunctionMode):
@@ -71,6 +93,7 @@ def __init__(self, query_key_save_callback: Callable[[torch.Tensor, torch.Tensor
     def __torch_function__(self, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
+        # TODO(aryan): revisit for torch.compile -- can trace into ATen but not triton, so this branch is never hit
         if func is torch.nn.functional.scaled_dot_product_attention:
             query = kwargs.get("query", None) or args[0]
             key = kwargs.get("key", None) or args[1]
@@ -97,7 +120,7 @@ def new_forward(self, module, *args, **kwargs):
         # 3. [batch_size, attention_heads, encoder_sequence_length + latents_sequence_length, head_dim]
         kwargs_hidden_states = kwargs.get("hidden_states", None)
        hidden_states = kwargs_hidden_states if kwargs_hidden_states is not None else args[0]
-        self.state.latents_sequence_length = hidden_states.size(2)
+        self.state.latents_sequence_length = hidden_states.size(1)
 
         # Capture query and key tensors to compute EnhanceAVideo scores
         with EnhanceAVideoCaptureSDPAInputsFunctionMode(self._query_key_capture_callback):
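
The `size(2)` to `size(1)` change reads as a straightforward bug fix: at the module boundary, `hidden_states` is laid out `[batch_size, sequence_length, hidden_dim]`, so the sequence length lives on dim 1; it is only inside SDPA, after the heads are split out, that the sequence axis moves to dim 2. A toy illustration with made-up shapes:

```python
import torch

batch_size, seq_len, heads, head_dim = 2, 512, 8, 64

# Module input, as seen in new_forward: [batch, seq, hidden]
hidden_states = torch.randn(batch_size, seq_len, heads * head_dim)
assert hidden_states.size(1) == seq_len  # correct axis for sequence length
assert hidden_states.size(2) == heads * head_dim  # the old code read this instead

# Inside SDPA the tensors are [batch, heads, seq, head_dim]; there, seq is dim 2.
query = hidden_states.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
assert query.size(2) == seq_len
```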
@@ -130,7 +153,6 @@ def reshape_for_framewise_attention(tensor: torch.Tensor) -> torch.Tensor:
         query = reshape_for_framewise_attention(query)
         key = reshape_for_framewise_attention(key)
         scores = enhance_a_video_score(query, key, num_frames, self.weight)
-        print("Applying scores:", scores)
         hidden_states = hidden_states * scores
 
         return (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states
@@ -176,24 +198,105 @@ def enhance_a_video_score(
 
 
 def apply_enhance_a_video(module: torch.nn.Module, config: EnhanceAVideoConfig) -> None:
+    r"""
+    Applies [Enhance A Video](https://huggingface.co/papers/2502.07508) to a model.
+
+    This function applies a Diffusers Hook to all/user-configured self-attention layers of the model. The hook
+    captures the inputs entering `torch.nn.functional.scaled_dot_product_attention` and operates on them.
+
+    Args:
+        module (`torch.nn.Module`):
+            The model to apply Enhance A Video to. It must be a video generation model.
+        config (`EnhanceAVideoConfig`):
+            The configuration for Enhance A Video.
+
+    Example:
+
+    ```python
+    >>> import torch
+    >>> from diffusers import LTXPipeline
+    >>> from diffusers.hooks import apply_enhance_a_video, remove_enhance_a_video, EnhanceAVideoConfig
+    >>> from diffusers.utils import export_to_video
+
+    >>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
+    >>> pipe.to("cuda")
+
+    >>> weight = 4
+    >>> num_frames = 161
+    >>> latent_num_frames = (num_frames - 1) // pipe.vae_temporal_compression_ratio + 1
+
+    >>> # Apply Enhance-A-Video to all layers with a weight of 4
+    >>> config = EnhanceAVideoConfig(weight=weight, num_frames_callback=lambda: latent_num_frames, _attention_type=1)
+    >>> apply_enhance_a_video(pipe.transformer, config)
+
+    >>> prompt = "A man standing in a sunlit garden, surrounded by lush greenery and colorful flowers. The man has a knife in his hand and is cutting a ripe, juicy watermelon. The watermelon is bright red and contrasts beautifully with the green foliage in the background."
+    >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+    >>> video = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
+    >>> export_to_video(video, "output.mp4", fps=24)
+
+    >>> # Remove Enhance-A-Video
+    >>> remove_enhance_a_video(pipe.transformer)
+
+    >>> # Apply Enhance-A-Video to specific layers with different weights
+    >>> config = EnhanceAVideoConfig(
+    ...     weight={
+    ...         "blocks\.(0|1|2|3|4|5|6|7)\.": 5.0,
+    ...         "blocks\.(10|11|12|13|14|15)\.": 8.0,
+    ...         "blocks\.(21|22|23|24|25|26)\.": 3.0,
+    ...     },
+    ...     num_frames_callback=lambda: latent_num_frames,
+    ...     _attention_type=1,
+    ... )
+    ```
+    """
+    weight = config.weight
+    if not isinstance(weight, dict):
+        weight = {".*": config.weight}
+    _validate_weight(module, weight)
+
+    weight_keys = set(weight.keys())
     for name, submodule in module.named_modules():
+        # We cannot apply Enhance-A-Video to cross-attention layers
         is_cross_attention = getattr(submodule, "is_cross_attention", False)
         if not isinstance(submodule, _ATTENTION_CLASSES) or is_cross_attention:
             continue
+        current_weight = next(
+            (weight[identifier] for identifier in weight_keys if re.search(identifier, name) is not None), None
+        )
+        if current_weight is None:
+            continue
         logger.debug(f"Applying Enhance-A-Video to layer '{name}'")
         hook_registry = HookRegistry.check_if_exists_or_initialize(submodule)
         hook = EnhanceAVideoSDPAHook(
-            weight=config.weight,
+            weight=current_weight,
             num_frames_callback=config.num_frames_callback,
             _attention_type=config._attention_type,
         )
         hook_registry.register_hook(hook, _ENHANCE_A_VIDEO)
 
 
 def remove_enhance_a_video(module: torch.nn.Module) -> None:
+    r"""
+    Removes the Enhance A Video hook from the model.
+
+    See [`~hooks.enhance_a_video.apply_enhance_a_video`] for an example.
+    """
     for name, submodule in module.named_modules():
         if not hasattr(submodule, "_diffusers_hook"):
             continue
         hook_registry = submodule._diffusers_hook
         hook_registry.remove_hook(_ENHANCE_A_VIDEO, recurse=False)
         logger.debug(f"Removed Enhance-A-Video from layer '{name}'")
+
+
+def _validate_weight(module: torch.nn.Module, weight: Dict[str, float]) -> None:
+    if not isinstance(weight, dict):
+        raise ValueError(f"Invalid weight type: {type(weight)}")
+    weight_keys = set(weight.keys())
+    for name, _ in module.named_modules():
+        num_matches = sum(re.search(identifier, name) is not None for identifier in weight_keys)
+        if num_matches > 1:
+            raise ValueError(
+                f"The provided weight dictionary has multiple regex matches for layer '{name}'. Please provide non-overlapping regex patterns."
+            )
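
The non-overlap requirement that `_validate_weight` enforces can be seen with a tiny hypothetical module tree (the module names and patterns below are made up):

```python
import re
import torch

model = torch.nn.ModuleDict({"blocks": torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(2)])})

# Overlapping patterns: "blocks.0" is matched by both keys below,
# so _validate_weight would raise a ValueError for that layer.
overlapping = {r"blocks\.": 1.0, r"blocks\.0": 2.0}
for name, _ in model.named_modules():
    num_matches = sum(re.search(p, name) is not None for p in overlapping)
    print(name, num_matches)  # "blocks.0" -> 2: rejected
```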

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 6 additions & 6 deletions
@@ -22,9 +22,9 @@
 from ..utils import logging
 from ._common import (
     _ATTENTION_CLASSES,
-    _CROSS_ATTENTION_BLOCK_IDENTIFIERS,
-    _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS,
-    _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS,
+    _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
 )
 from .hooks import HookRegistry, ModelHook
 
@@ -75,9 +75,9 @@ class PyramidAttentionBroadcastConfig:
     temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
     cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
 
-    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
 
     current_timestep_callback: Callable[[], int] = None
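
Since these defaults are plain tuples of identifier strings, a model whose block list uses an attribute name outside the shared defaults can still opt in by overriding the corresponding field. A hypothetical sketch (`custom_blocks` and the `pipe.current_timestep` attribute are assumptions, not part of this diff):

```python
from diffusers import PyramidAttentionBroadcastConfig

# Hypothetical: a transformer that stores its blocks under `custom_blocks`
# can point Pyramid Attention Broadcast at them explicitly.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_block_identifiers=("custom_blocks",),
    current_timestep_callback=lambda: pipe.current_timestep,  # `pipe` assumed in scope
)
```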
