Skip to content

Commit 53b6b9f

Browse files
committed
perturbed attention guidance
1 parent 4664356 commit 53b6b9f

File tree

8 files changed

+400
-17
lines changed

8 files changed

+400
-17
lines changed

src/diffusers/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,13 @@
131131

132132
else:
133133
_import_structure["guiders"].extend(
134-
["AdaptiveProjectedGuidance", "ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance"]
134+
[
135+
"AdaptiveProjectedGuidance",
136+
"ClassifierFreeGuidance",
137+
"ClassifierFreeZeroStarGuidance",
138+
"PerturbedAttentionGuidance",
139+
"SkipLayerGuidance",
140+
]
135141
)
136142
_import_structure["hooks"].extend(
137143
[
@@ -720,6 +726,7 @@
720726
AdaptiveProjectedGuidance,
721727
ClassifierFreeGuidance,
722728
ClassifierFreeZeroStarGuidance,
729+
PerturbedAttentionGuidance,
723730
SkipLayerGuidance,
724731
)
725732
from .hooks import (

src/diffusers/guiders/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@
2020
from .classifier_free_guidance import ClassifierFreeGuidance
2121
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
2222
from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning
23+
from .perturbed_attention_guidance import PerturbedAttentionGuidance
2324
from .skip_layer_guidance import SkipLayerGuidance

src/diffusers/guiders/guider_utils.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Dict, List, Optional, Tuple, Union
15+
import re
16+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
1617

1718
import torch
1819

1920
from ..utils import deprecate, get_logger
2021

2122

23+
if TYPE_CHECKING:
24+
from ..models.attention_processor import AttentionProcessor
25+
26+
2227
logger = get_logger(__name__) # pylint: disable=invalid-name
2328

2429

@@ -129,6 +134,72 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
129134
return noise_cfg
130135

131136

137+
def _replace_attention_processors(
138+
module: torch.nn.Module,
139+
pag_applied_layers: Optional[Union[str, List[str]]] = None,
140+
skip_context_attention: bool = False,
141+
processors: Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]] = None,
142+
metadata_name: Optional[str] = None,
143+
) -> Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]]:
144+
if processors is not None and metadata_name is not None:
145+
raise ValueError("Cannot pass both `processors` and `metadata_name` at the same time.")
146+
if metadata_name is not None:
147+
if isinstance(pag_applied_layers, str):
148+
pag_applied_layers = [pag_applied_layers]
149+
return _replace_layers_with_guidance_processors(
150+
module, pag_applied_layers, skip_context_attention, metadata_name
151+
)
152+
if processors is not None:
153+
_replace_layers_with_existing_processors(processors)
154+
155+
156+
def _replace_layers_with_guidance_processors(
157+
module: torch.nn.Module,
158+
pag_applied_layers: List[str],
159+
skip_context_attention: bool,
160+
metadata_name: str,
161+
) -> List[Tuple[torch.nn.Module, "AttentionProcessor"]]:
162+
from ..hooks._common import _ATTENTION_CLASSES
163+
from ..hooks._helpers import GuidanceMetadataRegistry
164+
165+
processors = []
166+
for name, submodule in module.named_modules():
167+
if (
168+
(not isinstance(submodule, _ATTENTION_CLASSES))
169+
or (getattr(submodule, "processor", None) is None)
170+
or not (
171+
any(
172+
re.search(pag_layer, name) is not None and not _is_fake_integral_match(pag_layer, name)
173+
for pag_layer in pag_applied_layers
174+
)
175+
)
176+
):
177+
continue
178+
old_attention_processor = submodule.processor
179+
metadata = GuidanceMetadataRegistry.get(old_attention_processor.__class__)
180+
new_attention_processor_cls = getattr(metadata, metadata_name)
181+
new_attention_processor = new_attention_processor_cls()
182+
# !!! dunder methods cannot be replaced on instances !!!
183+
# if "skip_context_attention" in inspect.signature(new_attention_processor.__call__).parameters:
184+
# new_attention_processor.__call__ = partial(
185+
# new_attention_processor.__call__, skip_context_attention=skip_context_attention
186+
# )
187+
submodule.processor = new_attention_processor
188+
processors.append((submodule, old_attention_processor))
189+
return processors
190+
191+
192+
def _replace_layers_with_existing_processors(processors: List[Tuple[torch.nn.Module, "AttentionProcessor"]]) -> None:
193+
for module, proc in processors:
194+
module.processor = proc
195+
196+
197+
def _is_fake_integral_match(layer_id, name):
198+
layer_id = layer_id.split(".")[-1]
199+
name = name.split(".")[-1]
200+
return layer_id.isnumeric() and name.isnumeric() and layer_id == name
201+
202+
132203
def _raise_guidance_deprecation_warning(
133204
*,
134205
guidance_scale: Optional[float] = None,
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
from typing import List, Optional, Tuple, Union
17+
18+
import torch
19+
20+
from .guider_utils import GuidanceMixin, _replace_attention_processors, rescale_noise_cfg
21+
22+
23+
class PerturbedAttentionGuidance(GuidanceMixin):
24+
"""
25+
Perturbed Attention Guidance (PAB): https://huggingface.co/papers/2403.17377
26+
27+
Args:
28+
pag_applied_layers (`str` or `List[str]`):
29+
The name of the attention layers where Perturbed Attention Guidance is applied. This can be a single layer
30+
name or a list of layer names. The names should either be FQNs (fully qualified names) to each attention
31+
layer or a regex pattern that matches the FQNs of the attention layers. For example, if you want to apply
32+
PAG to transformer blocks 10 and 20, you can set this to `["transformer_blocks.10",
33+
"transformer_blocks.20"]`, or `"transformer_blocks.(10|20)"`.
34+
guidance_scale (`float`, defaults to `7.5`):
35+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
36+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
37+
deterioration of image quality.
38+
pag_scale (`float`, defaults to `3.0`):
39+
The scale parameter for perturbed attention guidance.
40+
guidance_rescale (`float`, defaults to `0.0`):
41+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
42+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
43+
Flawed](https://huggingface.co/papers/2305.08891).
44+
use_original_formulation (`bool`, defaults to `False`):
45+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
46+
we use the diffusers-native implementation that has been in the codebase for a long time. See
47+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
48+
"""
49+
50+
_input_predictions = ["pred_cond", "pred_uncond", "pred_perturbed"]
51+
52+
def __init__(
53+
self,
54+
pag_applied_layers: Union[str, List[str]],
55+
guidance_scale: float = 7.5,
56+
pag_scale: float = 3.0,
57+
skip_context_attention: bool = False,
58+
guidance_rescale: float = 0.0,
59+
use_original_formulation: bool = False,
60+
):
61+
super().__init__()
62+
63+
self.pag_applied_layers = pag_applied_layers
64+
self.guidance_scale = guidance_scale
65+
self.pag_scale = pag_scale
66+
self.skip_context_attention = skip_context_attention
67+
self.guidance_rescale = guidance_rescale
68+
self.use_original_formulation = use_original_formulation
69+
70+
self._is_pag_batch = False
71+
self._original_processors = None
72+
self._denoiser = None
73+
74+
def prepare_models(self, denoiser: torch.nn.Module):
75+
self._denoiser = denoiser
76+
77+
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
78+
num_conditions = self.num_conditions
79+
list_of_inputs = []
80+
for arg in args:
81+
if isinstance(arg, torch.Tensor):
82+
list_of_inputs.append([arg] * num_conditions)
83+
elif isinstance(arg, (tuple, list)):
84+
if len(arg) != 2:
85+
raise ValueError(
86+
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
87+
f"with the first element being the conditional input and the second element being the unconditional input or None."
88+
)
89+
if arg[1] is None:
90+
# Only conditioning inputs for all batches
91+
list_of_inputs.append([arg[0]] * num_conditions)
92+
else:
93+
list_of_inputs.append([arg[0], arg[1], arg[0]])
94+
else:
95+
raise ValueError(
96+
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
97+
)
98+
return tuple(list_of_inputs)
99+
100+
def prepare_outputs(self, pred: torch.Tensor) -> None:
101+
self._num_outputs_prepared += 1
102+
if self._num_outputs_prepared > self.num_conditions:
103+
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
104+
key = self._input_predictions[self._num_outputs_prepared - 1]
105+
if not self._is_cfg_enabled() and self._is_pag_enabled():
106+
# If we're predicting pred_cond and pred_perturbed only, we need to set the key to pred_perturbed
107+
# to avoid writing into pred_uncond which is not used
108+
if self._num_outputs_prepared == 2:
109+
key = "pred_perturbed"
110+
self._preds[key] = pred
111+
112+
# Prepare denoiser for perturbed attention prediction if needed
113+
if not self._is_pag_enabled():
114+
return
115+
should_register_pag = (self._is_cfg_enabled() and self._num_outputs_prepared == 2) or (
116+
not self._is_cfg_enabled() and self._num_outputs_prepared == 1
117+
)
118+
if should_register_pag:
119+
self._is_pag_batch = True
120+
self._original_processors = _replace_attention_processors(
121+
self._denoiser,
122+
self.pag_applied_layers,
123+
skip_context_attention=self.skip_context_attention,
124+
metadata_name="perturbed_attention_guidance_processor_cls",
125+
)
126+
elif self._is_pag_batch:
127+
# Restore the original attention processors
128+
_replace_attention_processors(self._denoiser, processors=self._original_processors)
129+
self._is_pag_batch = False
130+
self._original_processors = None
131+
132+
def cleanup_models(self, denoiser: torch.nn.Module):
133+
self._denoiser = None
134+
135+
def forward(
136+
self,
137+
pred_cond: torch.Tensor,
138+
pred_uncond: Optional[torch.Tensor] = None,
139+
pred_perturbed: Optional[torch.Tensor] = None,
140+
) -> torch.Tensor:
141+
pred = None
142+
143+
if not self._is_cfg_enabled() and not self._is_pag_enabled():
144+
pred = pred_cond
145+
elif not self._is_cfg_enabled():
146+
shift = pred_cond - pred_perturbed
147+
pred = pred_cond + self.pag_scale * shift
148+
elif not self._is_pag_enabled():
149+
shift = pred_cond - pred_uncond
150+
pred = pred_cond if self.use_original_formulation else pred_uncond
151+
pred = pred + self.guidance_scale * shift
152+
else:
153+
shift = pred_cond - pred_uncond
154+
shift_perturbed = pred_cond - pred_perturbed
155+
pred = pred_cond if self.use_original_formulation else pred_uncond
156+
pred = pred + self.guidance_scale * shift + self.pag_scale * shift_perturbed
157+
158+
if self.guidance_rescale > 0.0:
159+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
160+
161+
return pred
162+
163+
@property
164+
def num_conditions(self) -> int:
165+
num_conditions = 1
166+
if self._is_cfg_enabled():
167+
num_conditions += 1
168+
if self._is_pag_enabled():
169+
num_conditions += 1
170+
return num_conditions
171+
172+
def _is_cfg_enabled(self) -> bool:
173+
if self.use_original_formulation:
174+
return not math.isclose(self.guidance_scale, 0.0)
175+
else:
176+
return not math.isclose(self.guidance_scale, 1.0)
177+
178+
def _is_pag_enabled(self) -> bool:
179+
is_zero = math.isclose(self.pag_scale, 0.0)
180+
return not is_zero

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -186,30 +186,22 @@ def forward(
186186
pred_cond_skip: Optional[torch.Tensor] = None,
187187
) -> torch.Tensor:
188188
pred = None
189-
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
190-
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
191189

192190
if not self._is_cfg_enabled() and not self._is_slg_enabled():
193191
pred = pred_cond
194192
elif not self._is_cfg_enabled():
195-
if skip_start_step < self._step < skip_stop_step:
196-
shift = pred_cond - pred_cond_skip
197-
pred = pred_cond if self.use_original_formulation else pred_cond_skip
198-
pred = pred + self.skip_layer_guidance_scale * shift
199-
else:
200-
pred = pred_cond
193+
shift = pred_cond - pred_cond_skip
194+
pred = pred_cond if self.use_original_formulation else pred_cond_skip
195+
pred = pred + self.skip_layer_guidance_scale * shift
201196
elif not self._is_slg_enabled():
202197
shift = pred_cond - pred_uncond
203198
pred = pred_cond if self.use_original_formulation else pred_uncond
204199
pred = pred + self.guidance_scale * shift
205200
else:
206201
shift = pred_cond - pred_uncond
202+
shift_skip = pred_cond - pred_cond_skip
207203
pred = pred_cond if self.use_original_formulation else pred_uncond
208-
pred = pred + self.guidance_scale * shift
209-
210-
if skip_start_step < self._step < skip_stop_step:
211-
shift_skip = pred_cond - pred_cond_skip
212-
pred = pred + self.skip_layer_guidance_scale * shift_skip
204+
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
213205

214206
if self.guidance_rescale > 0.0:
215207
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
@@ -234,4 +226,6 @@ def _is_cfg_enabled(self) -> bool:
234226
def _is_slg_enabled(self) -> bool:
235227
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
236228
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
237-
return skip_start_step < self._step < skip_stop_step
229+
is_within_range = skip_start_step < self._step < skip_stop_step
230+
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
231+
return is_within_range and not is_zero

0 commit comments

Comments
 (0)