161 changes: 144 additions & 17 deletions src/diffusers/guiders/perturbed_attention_guidance.py
@@ -12,18 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Union
import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch

from ..configuration_utils import register_to_config
from ..hooks import LayerSkipConfig
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from ..utils import get_logger
from .skip_layer_guidance import SkipLayerGuidance
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState


logger = get_logger(__name__) # pylint: disable=invalid-name


class PerturbedAttentionGuidance(SkipLayerGuidance):
class PerturbedAttentionGuidance(BaseGuidance):
"""
Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377

@@ -36,7 +44,7 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
Additional reading:
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)

PAG is implemented as a specialization of the SkipLayerGuidance due to similarities in the configuration parameters
PAG is implemented similarly to SkipLayerGuidance due to overlap in the configuration parameters
and implementation details.

Args:
@@ -75,6 +83,8 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
# complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
# for each model architecture.

_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]

@register_to_config
def __init__(
self,
@@ -89,6 +99,15 @@ def __init__(
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)

self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = perturbed_guidance_scale
self.skip_layer_guidance_start = perturbed_guidance_start
self.skip_layer_guidance_stop = perturbed_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation

if perturbed_guidance_config is None:
if perturbed_guidance_layers is None:
raise ValueError(
@@ -130,15 +149,123 @@ def __init__(
config.skip_attention_scores = True
config.skip_ff = False

super().__init__(
guidance_scale=guidance_scale,
skip_layer_guidance_scale=perturbed_guidance_scale,
skip_layer_guidance_start=perturbed_guidance_start,
skip_layer_guidance_stop=perturbed_guidance_stop,
skip_layer_guidance_layers=perturbed_guidance_layers,
skip_layer_config=perturbed_guidance_config,
guidance_rescale=guidance_rescale,
use_original_formulation=use_original_formulation,
start=start,
stop=stop,
)
self.skip_layer_config = perturbed_guidance_config
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]

# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
def prepare_models(self, denoiser: torch.nn.Module) -> None:
self._count_prepared += 1
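# Apply the layer-skip hooks only when preparing the perturbed conditional pass
# (a conditional pass after the first), so the other passes run unmodified.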
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
_apply_layer_skip_hook(denoiser, config, name=name)

# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)

# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields

if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
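# Three passes: conditional, unconditional, and a perturbed conditional pass that
# reuses the conditional inputs (tuple index 0) while attention-skip hooks are active.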
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches

# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict[str, Any]]:
pred = None

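# With both CFG and SLG active, two guidance shifts are applied on top of the
# base prediction (pred_uncond, or pred_cond under the original formulation):
#   pred = base + guidance_scale * (pred_cond - pred_uncond)
#               + skip_layer_guidance_scale * (pred_cond - pred_cond_skip)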
if not self._is_cfg_enabled() and not self._is_slg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_cond_skip
pred = pred + self.skip_layer_guidance_scale * shift
elif not self._is_slg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip

if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}

@property
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3

@property
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_slg_enabled():
num_conditions += 1
return num_conditions

# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False

is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step

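# CFG is a no-op when the scale collapses the formula to the conditional
# prediction alone: 0.0 under the original formulation, 1.0 otherwise.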
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)

return is_within_range and not is_close

# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
def _is_slg_enabled(self) -> bool:
if not self._enabled:
return False

is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step

is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)

return is_within_range and not is_zero
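
For orientation, a minimal usage sketch of the guider above. This is not part of the PR; the layer identifier "blocks.20" is a hypothetical placeholder, since valid names depend on the denoiser architecture.

# A minimal sketch, assuming the module path from this diff; "blocks.20" is a
# hypothetical layer name, not taken from the PR.
from diffusers.guiders.perturbed_attention_guidance import PerturbedAttentionGuidance

guider = PerturbedAttentionGuidance(
    guidance_scale=7.5,
    perturbed_guidance_scale=3.0,
    perturbed_guidance_layers=["blocks.20"],
)

# With both CFG and PAG active, num_conditions is 3: the pipeline prepares the
# denoiser three times per step (conditional, unconditional, perturbed
# conditional) and combines the predictions in forward().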
2 changes: 1 addition & 1 deletion src/diffusers/modular_pipelines/modular_pipeline.py
@@ -335,7 +335,7 @@ def init_pipeline(
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
components_manager: Optional[ComponentsManager] = None,
collection: Optional[str] = None,
):
) -> "ModularPipeline":
"""
Create a ModularPipeline, optionally accepting a modular repo to load from the Hub.
"""