diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py
new file mode 100644
index 000000000000..2b4351d4a94e
--- /dev/null
+++ b/src/diffusers/models/hooks.py
@@ -0,0 +1,251 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+from typing import Any, Callable, Dict, Tuple, Union
+
+import torch
+
+
+# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
+class ModelHook:
+    r"""
+    A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
+    from existing PyTorch hooks is that these also get passed the kwargs.
+    """
+
+    def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+        r"""
+        Hook that is executed when a model is initialized.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module attached to this hook.
+        """
+        return module
+
+    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
+        r"""
+        Hook that is executed just before the forward method of the model.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module whose forward pass will be executed just after this event.
+            args (`Tuple[Any]`):
+                The positional arguments passed to the module.
+            kwargs (`Dict[str, Any]`):
+                The keyword arguments passed to the module.
+
+        Returns:
+            `Tuple[Tuple[Any], Dict[str, Any]]`:
+                A tuple with the treated `args` and `kwargs`.
+        """
+        return args, kwargs
+
+    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
+        r"""
+        Hook that is executed just after the forward method of the model.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module whose forward pass has been executed just before this event.
+            output (`Any`):
+                The output of the module.
+
+        Returns:
+            `Any`: The processed `output`.
+        """
+        return output
+
+    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+        r"""
+        Hook that is executed when the hook is detached from a module.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module detached from this hook.
+ """ + return module + + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + return module + + +class SequentialHook(ModelHook): + r"""A hook that can contain several hooks and iterates through them at each event.""" + + def __init__(self, *hooks): + self.hooks = hooks + + def init_hook(self, module): + for hook in self.hooks: + module = hook.init_hook(module) + return module + + def pre_forward(self, module, *args, **kwargs): + for hook in self.hooks: + args, kwargs = hook.pre_forward(module, *args, **kwargs) + return args, kwargs + + def post_forward(self, module, output): + for hook in self.hooks: + output = hook.post_forward(module, output) + return output + + def detach_hook(self, module): + for hook in self.hooks: + module = hook.detach_hook(module) + return module + + def reset_state(self, module): + for hook in self.hooks: + module = hook.reset_state(module) + return module + + +class PyramidAttentionBroadcastHook(ModelHook): + def __init__( + self, + skip_range: int, + timestep_range: Tuple[int, int], + timestep_callback: Callable[[], Union[torch.LongTensor, int]], + ) -> None: + super().__init__() + + self.skip_range = skip_range + self.timestep_range = timestep_range + self.timestep_callback = timestep_callback + + self.attention_cache = None + self._iteration = 0 + + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + + current_timestep = self.timestep_callback() + is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1] + should_compute_attention = self._iteration % self.skip_range == 0 + + if not is_within_timestep_range or should_compute_attention: + output = module._old_forward(*args, **kwargs) + else: + output = self.attention_cache + + self._iteration = self._iteration + 1 + + return module._diffusers_hook.post_forward(module, output) + + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: + self.attention_cache = output + return output + + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + self.attention_cache = None + self._iteration = 0 + return module + + +def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): + r""" + Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove + this behavior and restore the original `forward` method, use `remove_hook_from_module`. + + + + If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks + together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class. + + + + Args: + module (`torch.nn.Module`): + The module to attach a hook to. + hook (`ModelHook`): + The hook to attach. + append (`bool`, *optional*, defaults to `False`): + Whether the hook should be chained with an existing one (if module already contains a hook) or not. + + Returns: + `torch.nn.Module`: + The same module, with the hook attached (the module is modified in place, so the result can be discarded). + """ + original_hook = hook + + if append and getattr(module, "_diffusers_hook", None) is not None: + old_hook = module._diffusers_hook + remove_hook_from_module(module) + hook = SequentialHook(old_hook, hook) + + if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"): + # If we already put some hook on this module, we replace it with the new one. 
+ old_forward = module._old_forward + else: + old_forward = module.forward + module._old_forward = old_forward + + module = hook.init_hook(module) + module._diffusers_hook = hook + + if hasattr(original_hook, "new_forward"): + new_forward = original_hook.new_forward + else: + + def new_forward(module, *args, **kwargs): + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + output = module._old_forward(*args, **kwargs) + return module._diffusers_hook.post_forward(module, output) + + # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. + # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 + if "GraphModuleImpl" in str(type(module)): + module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) + else: + module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) + + return module + + +def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module: + """ + Removes any hook attached to a module via `add_hook_to_module`. + + Args: + module (`torch.nn.Module`): + The module to attach a hook to. + recurse (`bool`, defaults to `False`): + Whether to remove the hooks recursively + + Returns: + `torch.nn.Module`: + The same module, with the hook detached (the module is modified in place, so the result can be discarded). + """ + + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.detach_hook(module) + delattr(module, "_diffusers_hook") + + if hasattr(module, "_old_forward"): + # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. + # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 + if "GraphModuleImpl" in str(type(module)): + module.__class__.forward = module._old_forward + else: + module.forward = module._old_forward + delattr(module, "_old_forward") + + if recurse: + for child in module.children(): + remove_hook_from_module(child, recurse) + + return module diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 9314960f9618..10dd6455092d 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -38,6 +38,7 @@ ) from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor +from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import AllegroPipelineOutput @@ -131,7 +132,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class AllegroPipeline(DiffusionPipeline): +class AllegroPipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin): r""" Pipeline for text-to-video generation using Allegro. @@ -786,6 +787,7 @@ def __call__( negative_prompt_attention_mask, ) self._guidance_scale = guidance_scale + self._current_timestep = None self._interrupt = False # 2. 
Default height and width to transformer @@ -863,6 +865,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -901,6 +904,8 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + self._current_timestep = None + if not output_type == "latent": latents = latents.to(self.vae.dtype) video = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 3655075bd519..9eeccec50621 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -30,6 +30,7 @@ from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor +from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import CogVideoXPipelineOutput @@ -144,7 +145,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): +class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin): r""" Pipeline for controlled text-to-video generation using CogVideoX Fun. @@ -650,6 +651,7 @@ def __call__( ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Default call parameters @@ -730,6 +732,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -779,6 +782,8 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + self._current_timestep = None + if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 2e1858b16148..aa790c830d1a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1082,6 +1082,10 @@ def maybe_free_model_hooks(self): is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions correctly when applying enable_model_cpu_offload. """ + + if hasattr(self, "_diffusers_hook"): + self._diffusers_hook.reset_state() + if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: # `enable_model_cpu_offload` has not be called, so silently do nothing return diff --git a/src/diffusers/pipelines/pyramid_broadcast_utils.py b/src/diffusers/pipelines/pyramid_broadcast_utils.py index 6e917568f33a..7fdb6a7f5b93 100644 --- a/src/diffusers/pipelines/pyramid_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_broadcast_utils.py @@ -12,106 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
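Note for reviewers: the pipeline hunks above all follow the same pattern — set `self._current_timestep` at the top of each denoising step, clear it after the loop (and in `maybe_free_model_hooks`), and let the hook read it back through `timestep_callback`. Below is a minimal, self-contained sketch of how the pieces added in this diff interact. It assumes this branch is installed so that `diffusers.models.hooks` exists; `ToyAttention`, the tensor shapes, and the timestep values are made up for illustration and are not part of the PR.

```python
import torch

# These names are introduced by src/diffusers/models/hooks.py in this diff.
from diffusers.models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module


class ToyAttention(torch.nn.Module):
    """Stand-in for an attention block; prints whenever it actually runs."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        print("attention computed")
        return hidden_states + 1.0


state = {"t": None}  # plays the role of the pipeline's `_current_timestep`

attn = ToyAttention()
add_hook_to_module(
    attn,
    PyramidAttentionBroadcastHook(
        skip_range=2,                     # re-compute on every 2nd call inside the window
        timestep_range=(100, 800),        # only skip while 100 < t < 800
        timestep_callback=lambda: state["t"],
    ),
    append=True,
)

x = torch.zeros(1, 4)
for t in [900, 700, 650, 600, 50]:
    state["t"] = t   # what the pipelines above do with `self._current_timestep = t`
    x = attn(x)      # "attention computed" is printed only when the hook recomputes
```

With these values the toy attention runs at t=900 and t=50 (outside the timestep window) and at every second call inside the window; the other calls return the cached output stored by `post_forward`.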
-import inspect -from typing import Optional, Tuple +from typing import List, Optional, Tuple -import torch import torch.nn as nn -from ..models.attention_processor import Attention, AttentionProcessor +from ..models.attention_processor import Attention +from ..models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module, remove_hook_from_module from ..utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class PyramidAttentionBroadcastAttentionProcessorWrapper: - r""" - Helper attention processor that wraps logic required for Pyramid Attention Broadcast to function. - - PAB works by caching and re-using attention computations from past inference steps. This is due to the realization - that the attention states do not differ too much numerically between successive inference steps. The difference is - most significant/prominent in the spatial attention blocks, lesser so in the temporal attention blocks, and least - in cross attention blocks. - - Currently, only spatial and cross attention block skipping is supported in Diffusers due to not having any models - tested with temporal attention blocks. Feel free to open a PR adding support for this in case there's a model that - you would like to use PAB with. - - Args: - pipeline ([`~diffusers.DiffusionPipeline`]): - The underlying DiffusionPipeline object that inherits from the PAB Mixin and utilized this attention - processor. - processor ([`~diffusers.models.attention_processor.AttentionProcessor`]): - The underlying attention processor that will be wrapped to cache the intermediate attention computation. - skip_range (`int`): - The attention block to execute after skipping intermediate attention blocks. If set to the value `N`, `N - - 1` attention blocks are skipped and every N'th block is executed. Different models have different - tolerances to how much attention computation can be reused based on the differences between successive - blocks. So, this parameter must be adjusted per model after performing experimentation. Setting this value - to `2` is recommended for different models PAB has been experimented with. - timestep_range (`Tuple[int, int]`): - The timestep range between which PAB will remain activated in attention blocks. While activated, PAB will - re-use attention computations between inference steps. 
- """ - - def __init__( - self, pipeline, processor: AttentionProcessor, skip_range: int, timestep_range: Tuple[int, int] - ) -> None: - self.pipeline = pipeline - self._original_processor = processor - self._skip_range = skip_range - self._timestep_range = timestep_range - - self._prev_hidden_states = None - self._iteration = 0 - - original_processor_params = set(inspect.signature(self._original_processor.__call__).parameters.keys()) - supported_parameters = { - "attn", - "hidden_states", - "encoder_hidden_states", - "attention_mask", - "temb", - "image_rotary_emb", - } - self._attn_processor_params = supported_parameters.intersection(original_processor_params) - - def __call__( - self, - attn: Attention, - hidden_states: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - *args, - **kwargs, - ): - r"""Method that wraps the underlying call to compute attention and cache states for re-use.""" - - if ( - hasattr(self.pipeline, "_current_timestep") - and self.pipeline._current_timestep is not None - and self._iteration % self._skip_range != 0 - and (self._timestep_range[0] < self.pipeline._current_timestep < self._timestep_range[1]) - ): - # Skip attention computation by re-using past attention states - hidden_states = self._prev_hidden_states - else: - # Perform attention computation - call_kwargs = {} - for param in self._attn_processor_params: - call_kwargs.update({param: locals()[param]}) - call_kwargs.update(kwargs) - hidden_states = self._original_processor(*args, **call_kwargs) - self._prev_hidden_states = hidden_states - - self._iteration = (self._iteration + 1) % self.pipeline.num_timesteps - - return hidden_states - - class PyramidAttentionBroadcastMixin: r"""Mixin class for [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588).""" @@ -120,40 +32,68 @@ def _enable_pyramid_attention_broadcast(self) -> None: for name, module in denoiser.named_modules(): if isinstance(module, Attention): - logger.debug(f"Enabling Pyramid Attention Broadcast in layer: {name}") + is_spatial_attention = ( + any(x in name for x in self._pab_spatial_attn_layer_identifiers) + and self._pab_spatial_attn_skip_range is not None + and not module.is_cross_attention + ) + is_temporal_attention = ( + any(x in name for x in self._pab_temporal_attn_layer_identifiers) + and self._pab_temporal_attn_skip_range is not None + and not module.is_cross_attention + ) + is_cross_attention = ( + any(x in name for x in self._pab_cross_attn_layer_identifiers) + and self._pab_cross_attn_skip_range is not None + and module.is_cross_attention + ) - skip_range, timestep_range = None, None - if module.is_cross_attention and self._pab_cross_attn_skip_range is not None: - skip_range = self._pab_cross_attn_skip_range - timestep_range = self._pab_cross_attn_timestep_range - if not module.is_cross_attention and self._pab_spatial_attn_skip_range is not None: + if is_spatial_attention: skip_range = self._pab_spatial_attn_skip_range timestep_range = self._pab_spatial_attn_timestep_range + if is_temporal_attention: + skip_range = self._pab_temporal_attn_skip_range + timestep_range = self._pab_temporal_attn_timestep_range + if is_cross_attention: + skip_range = self._pab_cross_attn_skip_range + timestep_range = self._pab_cross_attn_timestep_range if skip_range is None: continue - module.set_processor( - PyramidAttentionBroadcastAttentionProcessorWrapper( - 
-                        self, module.processor, skip_range, timestep_range
-                    )
+                logger.debug(f"Enabling Pyramid Attention Broadcast in layer: {name}")
+
+                add_hook_to_module(
+                    module,
+                    PyramidAttentionBroadcastHook(
+                        skip_range=skip_range,
+                        timestep_range=timestep_range,
+                        timestep_callback=self._pyramid_attention_broadcast_timestep_callback,
+                    ),
+                    append=True,
                 )
 
     def _disable_pyramid_attention_broadcast(self) -> None:
         denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet
 
         for name, module in denoiser.named_modules():
-            if isinstance(module, Attention) and isinstance(
-                module.processor, PyramidAttentionBroadcastAttentionProcessorWrapper
-            ):
-                logger.debug(f"Disabling Pyramid Attention Broadcast in layer: {name}")
-                module.processor = module.processor._original_processor
+            logger.debug(f"Disabling Pyramid Attention Broadcast in layer: {name}")
+            remove_hook_from_module(module)
+
+    def _pyramid_attention_broadcast_timestep_callback(self):
+        return self._current_timestep
 
     def enable_pyramid_attention_broadcast(
         self,
         spatial_attn_skip_range: Optional[int] = None,
+        spatial_attn_timestep_range: Tuple[int, int] = (100, 800),
+        temporal_attn_skip_range: Optional[int] = None,
         cross_attn_skip_range: Optional[int] = None,
-        spatial_attn_timestep_range: Optional[Tuple[int, int]] = None,
-        cross_attn_timestep_range: Optional[Tuple[int, int]] = None,
+        temporal_attn_timestep_range: Tuple[int, int] = (100, 800),
+        cross_attn_timestep_range: Tuple[int, int] = (100, 800),
+        spatial_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"],
+        temporal_attn_layer_identifiers: List[str] = ["temporal_transformer_blocks"],
+        cross_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"],
     ) -> None:
         r"""
         Enable pyramid attention broadcast to speedup inference by re-using attention states and skipping computation
@@ -166,41 +106,53 @@ def enable_pyramid_attention_broadcast(
                 different tolerances to how much attention computation can be reused based on the differences between
                 successive blocks. So, this parameter must be adjusted per model after performing experimentation.
                 Setting this value to `2` is recommended for different models PAB has been experimented with.
+            temporal_attn_skip_range (`int`, *optional*):
+                The attention block to execute after skipping intermediate temporal attention blocks. If set to the
+                value `N`, `N - 1` attention blocks are skipped and every N'th block is executed. Different models have
+                different tolerances to how much attention computation can be reused based on the differences between
+                successive blocks. So, this parameter must be adjusted per model after performing experimentation.
+                Setting this value to `4` is recommended for different models PAB has been experimented with.
             cross_attn_skip_range (`int`, *optional*):
                 The attention block to execute after skipping intermediate cross attention blocks. If set to the value
                 `N`, `N - 1` attention blocks are skipped and every N'th block is executed. Different models have
                 different tolerances to how much attention computation can be reused based on the differences between
                 successive blocks. So, this parameter must be adjusted per model after performing experimentation.
                 Setting this value to `6` is recommended for different models PAB has been experimented with.
- spatial_attn_timestep_range (`Tuple[int, int]`, *optional*): + spatial_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`): The timestep range between which PAB will remain activated in spatial attention blocks. While - activated, PAB will re-use attention computations between inference steps. Setting this to `(100, 850)` - is recommended for different models PAB has been experimented with. - cross_attn_timestep_range (`Tuple[int, int]`, *optional*): + activated, PAB will re-use attention computations between inference steps. + temporal_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`): + The timestep range between which PAB will remain activated in temporal attention blocks. While + activated, PAB will re-use attention computations between inference steps. + cross_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`): The timestep range between which PAB will remain activated in cross attention blocks. While activated, - PAB will re-use attention computations between inference steps. Setting this to `(100, 800)` is - recommended for different models PAB has been experimented with. + PAB will re-use attention computations between inference steps. """ - if spatial_attn_timestep_range is None: - spatial_attn_timestep_range = (100, 800) - if cross_attn_skip_range is None: - cross_attn_timestep_range = (100, 800) - if spatial_attn_timestep_range[0] > spatial_attn_timestep_range[1]: raise ValueError( "Expected `spatial_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." ) + if temporal_attn_timestep_range[0] > temporal_attn_timestep_range[1]: + raise ValueError( + "Expected `temporal_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." + ) if cross_attn_timestep_range[0] > cross_attn_timestep_range[1]: raise ValueError( "Expected `cross_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." 
) self._pab_spatial_attn_skip_range = spatial_attn_skip_range + self._pab_temporal_attn_skip_range = temporal_attn_skip_range self._pab_cross_attn_skip_range = cross_attn_skip_range self._pab_spatial_attn_timestep_range = spatial_attn_timestep_range + self._pab_temporal_attn_timestep_range = temporal_attn_timestep_range self._pab_cross_attn_timestep_range = cross_attn_timestep_range - self._pab_enabled = spatial_attn_skip_range or cross_attn_skip_range + self._pab_spatial_attn_layer_identifiers = spatial_attn_layer_identifiers + self._pab_temporal_attn_layer_identifiers = temporal_attn_layer_identifiers + self._pab_cross_attn_layer_identifiers = cross_attn_layer_identifiers + + self._pab_enabled = spatial_attn_skip_range or temporal_attn_skip_range or cross_attn_skip_range self._enable_pyramid_attention_broadcast() @@ -208,9 +160,14 @@ def disable_pyramid_attention_broadcast(self) -> None: r"""Disables the pyramid attention broadcast sampling mechanism.""" self._pab_spatial_attn_skip_range = None + self._pab_temporal_attn_skip_range = None self._pab_cross_attn_skip_range = None self._pab_spatial_attn_timestep_range = None + self._pab_temporal_attn_timestep_range = None self._pab_cross_attn_timestep_range = None + self._pab_spatial_attn_layer_identifiers = None + self._pab_temporal_attn_layer_identifiers = None + self._pab_cross_attn_layer_identifiers = None self._pab_enabled = False @property
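Finally, a hedged sketch of the user-facing API these changes add. `AllegroPipeline` gains `PyramidAttentionBroadcastMixin` in this diff; the checkpoint id, dtype, prompt, and the chosen skip values (2 for spatial and 6 for cross attention, as recommended in the docstrings above) are illustrative rather than prescriptive, and this assumes the PR branch is installed.

```python
import torch
from diffusers import AllegroPipeline

# Checkpoint id and dtype are illustrative; load whatever variant you normally use.
pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Enable Pyramid Attention Broadcast: while the current timestep is in (100, 800),
# spatial attention is re-computed every 2nd step and cross attention every 6th step;
# the cached attention states are broadcast to the skipped steps.
pipe.enable_pyramid_attention_broadcast(
    spatial_attn_skip_range=2,
    spatial_attn_timestep_range=(100, 800),
    cross_attn_skip_range=6,
    cross_attn_timestep_range=(100, 800),
)

video = pipe(prompt="A sailboat gliding across a calm sea at sunset").frames[0]

# Revert to full attention computation for subsequent calls.
pipe.disable_pyramid_attention_broadcast()
```

Because the hooks are attached to the denoiser's attention modules rather than swapped in as attention processors, disabling PAB simply removes the hooks and restores the original `forward` methods.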