1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import re
1516from dataclasses import dataclass
1617from enum import Enum
17- from typing import Callable
18+ from typing import Callable , Dict , Union
1819
1920import torch
2021import torch .overrides
3031_ENHANCE_A_VIDEO = "enhance_a_video"
3132
3233
33- class _AttentionType (Enum ):
34+ class _AttentionType (int , Enum ):
3435 SELF = 0
3536 JOINT___LATENTS_FIRST = 1
3637 JOINT___LATENTS_LAST = 2
@@ -40,9 +41,30 @@ class _AttentionType(Enum):
4041class EnhanceAVideoConfig :
4142 r"""
4243 Configuration for [Enhance A Video](https://huggingface.co/papers/2502.07508).
44+
45+ Args:
46+ weight (`float` or `Dict[str, float]`, defaults to `1.0`):
47+ The weighting factor for the Enhance A Video score. If a `float`, the same weight is applied to all layers.
48+ If a `dict`, the keys are regex patterns that match non-overlapping layer names, and the values are the
49+ corresponding weights.
50+ num_frames_callback (`Callable[[], int]`, `optional`):
51+ A callback function that returns the number of latent frames in the latent video stream. Since there is no
52+ easy way to deduce this within the attention layers, the user must provide this information.
53+ _attention_type (`_AttentionType`, defaults to `_AttentionType.SELF`):
54+ The type of attention mechanism that the underlying model uses. The following options are available:
55+ - `_AttentionType.SELF`:
56+ The model uses self-attention layers with only video tokens.
57+ - `_AttentionType.JOINT___LATENTS_FIRST`:
58+ The model uses joint attention layers (concatenated video and text stream data) with video tokens
59+ first.
60+ - `_AttentionType.JOINT___LATENTS_LAST`:
61+ The model uses joint attention layers (concatenated video and text stream data) with video tokens
62+ last.
63+
64+ This parameter is not backwards-compatible and may be subject to change in future versions.
4365 """
4466
45- weight : float = 1.0
67+ weight : Union [ float , Dict [ str , float ]] = 1.0
4668 num_frames_callback : Callable [[], int ] = None
4769 _attention_type : _AttentionType = _AttentionType .SELF
4870
@@ -59,7 +81,7 @@ def reset(self) -> None:
5981 self .latents_sequence_length = None
6082
6183 def __repr__ (self ):
62- return f"EnhanceAVideoAttentionState(scores= { self . scores } , latents_sequence_length={ self .latents_sequence_length } )"
84+ return f"EnhanceAVideoAttentionState(latents_sequence_length={ self .latents_sequence_length } )"
6385
6486
6587class EnhanceAVideoCaptureSDPAInputsFunctionMode (torch .overrides .TorchFunctionMode ):
@@ -71,6 +93,7 @@ def __init__(self, query_key_save_callback: Callable[[torch.Tensor, torch.Tensor
7193 def __torch_function__ (self , func , types , args = (), kwargs = None ):
7294 if kwargs is None :
7395 kwargs = {}
96+ # TODO(aryan): revisit for torch.compile -- can trace into ATen but not triton, so this branch is never hit
7497 if func is torch .nn .functional .scaled_dot_product_attention :
7598 query = kwargs .get ("query" , None ) or args [0 ]
7699 key = kwargs .get ("key" , None ) or args [1 ]
@@ -97,7 +120,7 @@ def new_forward(self, module, *args, **kwargs):
97120 # 3. [batch_size, attention_heads, encoder_sequence_length + latents_sequence_length, head_dim]
98121 kwargs_hidden_states = kwargs .get ("hidden_states" , None )
99122 hidden_states = kwargs_hidden_states if kwargs_hidden_states is not None else args [0 ]
100- self .state .latents_sequence_length = hidden_states .size (2 )
123+ self .state .latents_sequence_length = hidden_states .size (1 )
101124
102125 # Capture query and key tensors to compute EnhanceAVideo scores
103126 with EnhanceAVideoCaptureSDPAInputsFunctionMode (self ._query_key_capture_callback ):
@@ -130,7 +153,6 @@ def reshape_for_framewise_attention(tensor: torch.Tensor) -> torch.Tensor:
130153 query = reshape_for_framewise_attention (query )
131154 key = reshape_for_framewise_attention (key )
132155 scores = enhance_a_video_score (query , key , num_frames , self .weight )
133- print ("Applying scores:" , scores )
134156 hidden_states = hidden_states * scores
135157
136158 return (hidden_states , * output [1 :]) if isinstance (output , tuple ) else hidden_states
@@ -176,24 +198,105 @@ def enhance_a_video_score(
176198
177199
178200def apply_enhance_a_video (module : torch .nn .Module , config : EnhanceAVideoConfig ) -> None :
201+ r"""
202+ Applies [Enhance A Video](https://huggingface.co/papers/2502.07508) on a model.
203+
204+ This function applies a Diffusers Hook to all/user-configured self-attention lyaers of the model. The hook captures
205+ the inputs entering `torch.nn.functional.scaled_dot_product_attention` and operates on them.
206+
207+ Args:
208+ module (`torch.nn.Module`):
209+ The model to apply Enhance A Video to. It must be a video generation model.
210+ config (`EnhanceAVideoConfig`):
211+ The configuration for Enhance A Video.
212+
213+ Example:
214+
215+ ```python
216+ >>> import torch
217+ >>> from diffusers import LTXPipeline
218+ >>> from diffusers.hooks import apply_enhance_a_video, remove_enhance_a_video, EnhanceAVideoConfig
219+ >>> from diffusers.utils import export_to_video
220+
221+ >>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
222+ >>> pipe.to("cuda")
223+
224+ >>> weight = 4
225+ >>> num_frames = 161
226+ >>> latent_num_frames = (num_frames - 1) // pipe.vae_temporal_compression_ratio + 1
227+
228+ >>> # Apply Enhance-A-Video to all layers with a weight of 4
229+ >>> config = EnhanceAVideoConfig(weight=weight, num_frames_callback=lambda: latent_num_frames, _attention_type=1)
230+ >>> apply_enhance_a_video(pipe.transformer, config)
231+
232+ >>> prompt = "A man standing in a sunlit garden, surrounded by lush greenery and colorful flowers. The man has a knife in his hand and is cutting a ripe, juicy watermelon. The watermelon is bright red and contrasts beautifully with the green foliage in the background."
233+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
234+
235+ >>> video = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=161).frames[0]
236+ >>> export_to_video(video, "output.mp4", fps=24)
237+
238+ >>> # Remove Enhance-A-Video
239+ >>> remove_enhance_a_video(pipe.transformer)
240+
241+ >>> # Apply Enhance-A-Video to specific layers with different weights
242+ >>> config = EnhanceAVideoConfig(
243+ ... weight={
244+ ... "blocks\.(0|1|2|3|4|5|6|7)\.": 5.0,
245+ ... "blocks\.(10|11|12|13|14|15)\.": 8.0,
246+ ... "blocks\.(21|22|23|24|25|26)\.": 3.0,
247+ ... },
248+ ... num_frames_callback=lambda: latent_num_frames,
249+ ... _attention_type=1,
250+ ... )
251+ ```
252+ """
253+ weight = config .weight
254+ if not isinstance (weight , dict ):
255+ weight = {".*" : config .weight }
256+ _validate_weight (module , weight )
257+
258+ weight_keys = set (weight .keys ())
179259 for name , submodule in module .named_modules ():
260+ # We cannot apply Enhance-A-Video to cross-attention layers
180261 is_cross_attention = getattr (submodule , "is_cross_attention" , False )
181262 if not isinstance (submodule , _ATTENTION_CLASSES ) or is_cross_attention :
182263 continue
264+ current_weight = next (
265+ (weight [identifier ] for identifier in weight_keys if re .search (identifier , name ) is not None ), None
266+ )
267+ if current_weight is None :
268+ continue
183269 logger .debug (f"Applying Enhance-A-Video to layer '{ name } '" )
184270 hook_registry = HookRegistry .check_if_exists_or_initialize (submodule )
185271 hook = EnhanceAVideoSDPAHook (
186- weight = config . weight ,
272+ weight = current_weight ,
187273 num_frames_callback = config .num_frames_callback ,
188274 _attention_type = config ._attention_type ,
189275 )
190276 hook_registry .register_hook (hook , _ENHANCE_A_VIDEO )
191277
192278
193279def remove_enhance_a_video (module : torch .nn .Module ) -> None :
280+ r"""
281+ Removes the Enhance A Video hook from the model.
282+
283+ See [`~hooks.enhance_a_video.apply_enhance_a_video`] for an example.
284+ """
194285 for name , submodule in module .named_modules ():
195286 if not hasattr (submodule , "_diffusers_hook" ):
196287 continue
197288 hook_registry = submodule ._diffusers_hook
198289 hook_registry .remove_hook (_ENHANCE_A_VIDEO , recurse = False )
199290 logger .debug (f"Removed Enhance-A-Video from layer '{ name } '" )
291+
292+
293+ def _validate_weight (module : torch .nn .Module , weight : Dict [str , float ]) -> None :
294+ if not isinstance (weight , dict ):
295+ raise ValueError (f"Invalid weight type: { type (weight )} " )
296+ weight_keys = set (weight .keys ())
297+ for name , _ in module .named_modules ():
298+ num_matches = sum (re .search (identifier , name ) is not None for identifier in weight_keys )
299+ if num_matches > 1 :
300+ raise ValueError (
301+ f"The provided weight dictionary has multiple regex matches for layer '{ name } '. Please provide non-overlapping regex patterns."
302+ )
0 commit comments