diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py
index 3045f2feaae2..1b2256732ffc 100644
--- a/src/diffusers/guiders/perturbed_attention_guidance.py
+++ b/src/diffusers/guiders/perturbed_attention_guidance.py
@@ -12,18 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Union
+import math
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch
 
 from ..configuration_utils import register_to_config
-from ..hooks import LayerSkipConfig
+from ..hooks import HookRegistry, LayerSkipConfig
+from ..hooks.layer_skip import _apply_layer_skip_hook
 from ..utils import get_logger
-from .skip_layer_guidance import SkipLayerGuidance
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+    from ..modular_pipelines.modular_pipeline import BlockState
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class PerturbedAttentionGuidance(SkipLayerGuidance):
+class PerturbedAttentionGuidance(BaseGuidance):
     """
     Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
 
@@ -36,7 +44,7 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
     Additional reading:
     - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
 
-    PAG is implemented as a specialization of the SkipLayerGuidance due to similarities in the configuration parameters
+    PAG is implemented similarly to SkipLayerGuidance due to the overlap in configuration parameters
     and implementation details.
 
     Args:
@@ -75,6 +83,8 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
     # complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
     # for each model architecture.
 
+    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+
     @register_to_config
     def __init__(
         self,
@@ -89,6 +99,15 @@ def __init__(
         start: float = 0.0,
         stop: float = 1.0,
     ):
+        super().__init__(start, stop)
+
+        self.guidance_scale = guidance_scale
+        self.skip_layer_guidance_scale = perturbed_guidance_scale
+        self.skip_layer_guidance_start = perturbed_guidance_start
+        self.skip_layer_guidance_stop = perturbed_guidance_stop
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+
         if perturbed_guidance_config is None:
             if perturbed_guidance_layers is None:
                 raise ValueError(
@@ -130,15 +149,123 @@ def __init__(
             config.skip_attention_scores = True
             config.skip_ff = False
 
-        super().__init__(
-            guidance_scale=guidance_scale,
-            skip_layer_guidance_scale=perturbed_guidance_scale,
-            skip_layer_guidance_start=perturbed_guidance_start,
-            skip_layer_guidance_stop=perturbed_guidance_stop,
-            skip_layer_guidance_layers=perturbed_guidance_layers,
-            skip_layer_config=perturbed_guidance_config,
-            guidance_rescale=guidance_rescale,
-            use_original_formulation=use_original_formulation,
-            start=start,
-            stop=stop,
-        )
+        self.skip_layer_config = perturbed_guidance_config
+        self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
+    def prepare_models(self, denoiser: torch.nn.Module) -> None:
+        self._count_prepared += 1
+        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+            for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
+                _apply_layer_skip_hook(denoiser, config, name=name)
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
+    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
+            # Remove the hooks after inference
+            for hook_name in self._skip_layer_hook_names:
+                registry.remove_hook(hook_name, recurse=True)
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
+    def prepare_inputs(
+        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+    ) -> List["BlockState"]:
+        if input_fields is None:
+            input_fields = self._input_fields
+
+        if self.num_conditions == 1:
+            tuple_indices = [0]
+            input_predictions = ["pred_cond"]
+        elif self.num_conditions == 2:
+            tuple_indices = [0, 1]
+            input_predictions = (
+                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
+            )
+        else:
+            tuple_indices = [0, 1, 0]
+            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+        data_batches = []
+        for i in range(self.num_conditions):
+            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
+            data_batches.append(data_batch)
+        return data_batches
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
+    def forward(
+        self,
+        pred_cond: torch.Tensor,
+        pred_uncond: Optional[torch.Tensor] = None,
+        pred_cond_skip: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        pred = None
+
+        if not self._is_cfg_enabled() and not self._is_slg_enabled():
+            pred = pred_cond
+        elif not self._is_cfg_enabled():
+            shift = pred_cond - pred_cond_skip
+            pred = pred_cond if self.use_original_formulation else pred_cond_skip
+            pred = pred + self.skip_layer_guidance_scale * shift
+        elif not self._is_slg_enabled():
+            shift = pred_cond - pred_uncond
+            pred = pred_cond if self.use_original_formulation else pred_uncond
+            pred = pred + self.guidance_scale * shift
+        else:
+            shift = pred_cond - pred_uncond
+            shift_skip = pred_cond - pred_cond_skip
+            pred = pred_cond if self.use_original_formulation else pred_uncond
+            pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
+
+        if self.guidance_rescale > 0.0:
+            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+        return pred, {}
+
+    @property
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
+    def is_conditional(self) -> bool:
+        return self._count_prepared == 1 or self._count_prepared == 3
+
+    @property
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
+    def num_conditions(self) -> int:
+        num_conditions = 1
+        if self._is_cfg_enabled():
+            num_conditions += 1
+        if self._is_slg_enabled():
+            num_conditions += 1
+        return num_conditions
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
+    def _is_cfg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_close = False
+        if self.use_original_formulation:
+            is_close = math.isclose(self.guidance_scale, 0.0)
+        else:
+            is_close = math.isclose(self.guidance_scale, 1.0)
+
+        return is_within_range and not is_close
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
+    def _is_slg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
+            skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
+            is_within_range = skip_start_step < self._step < skip_stop_step
+
+        is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
+
+        return is_within_range and not is_zero
diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py
index 5440e5e5a6ff..57af0f220765 100644
--- a/src/diffusers/modular_pipelines/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/modular_pipeline.py
@@ -335,7 +335,7 @@ def init_pipeline(
         pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
         components_manager: Optional[ComponentsManager] = None,
         collection: Optional[str] = None,
-    ):
+    ) -> "ModularPipeline":
         """
         create a ModularPipeline, optionally accept modular_repo to load from hub.
         """
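
For reference, a minimal standalone sketch of the prediction-combining rule that the `forward` method above implements when both classifier-free guidance and the perturbed-attention branch are active, in the default (non-original) formulation. The tensor shapes and the two scale values are illustrative assumptions, not taken from the diff:

import torch

# Scales corresponding to `guidance_scale` and `perturbed_guidance_scale`
# (stored internally as `skip_layer_guidance_scale`); values are examples only.
guidance_scale = 7.5
skip_layer_guidance_scale = 3.0

# Three denoiser outputs gathered via `prepare_inputs`: conditional,
# unconditional, and conditional with attention scores skipped (identity
# attention) in the configured layers. Random tensors stand in for real
# model predictions here.
pred_cond = torch.randn(1, 4, 64, 64)
pred_uncond = torch.randn(1, 4, 64, 64)
pred_cond_skip = torch.randn(1, 4, 64, 64)

# Final branch of `forward`: the CFG shift plus the PAG shift, both applied
# on top of the unconditional prediction when `use_original_formulation=False`.
shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
pred = pred_uncond + guidance_scale * shift + skip_layer_guidance_scale * shift_skip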