Commit 25804b7

update

1 parent e0083b2 commit 25804b7

3 files changed, 149 insertions(+), 18 deletions(-)

src/diffusers/guiders/perturbed_attention_guidance.py

Lines changed: 144 additions & 17 deletions
@@ -12,18 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Union
+import math
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch
 
 from ..configuration_utils import register_to_config
-from ..hooks import LayerSkipConfig
+from ..hooks import HookRegistry, LayerSkipConfig
+from ..hooks.layer_skip import _apply_layer_skip_hook
 from ..utils import get_logger
-from .skip_layer_guidance import SkipLayerGuidance
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+    from ..modular_pipelines.modular_pipeline import BlockState
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class PerturbedAttentionGuidance(SkipLayerGuidance):
+class PerturbedAttentionGuidance(BaseGuidance):
     """
     Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
 
@@ -36,7 +44,7 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
     Additional reading:
     - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
 
-    PAG is implemented as a specialization of the SkipLayerGuidance due to similarities in the configuration parameters
+    PAG is implemented similarly to SkipLayerGuidance due to overlap in the configuration parameters
     and implementation details.
 
     Args:
@@ -75,6 +83,8 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
     # complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
     # for each model architecture.
 
+    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+
     @register_to_config
     def __init__(
         self,
@@ -89,6 +99,15 @@ def __init__(
         start: float = 0.0,
         stop: float = 1.0,
     ):
+        super().__init__(start, stop)
+
+        self.guidance_scale = guidance_scale
+        self.skip_layer_guidance_scale = perturbed_guidance_scale
+        self.skip_layer_guidance_start = perturbed_guidance_start
+        self.skip_layer_guidance_stop = perturbed_guidance_stop
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+
         if perturbed_guidance_config is None:
             if perturbed_guidance_layers is None:
                 raise ValueError(
@@ -130,15 +149,123 @@ def __init__(
             config.skip_attention_scores = True
             config.skip_ff = False
 
-        super().__init__(
-            guidance_scale=guidance_scale,
-            skip_layer_guidance_scale=perturbed_guidance_scale,
-            skip_layer_guidance_start=perturbed_guidance_start,
-            skip_layer_guidance_stop=perturbed_guidance_stop,
-            skip_layer_guidance_layers=perturbed_guidance_layers,
-            skip_layer_config=perturbed_guidance_config,
-            guidance_rescale=guidance_rescale,
-            use_original_formulation=use_original_formulation,
-            start=start,
-            stop=stop,
-        )
+        self.skip_layer_config = perturbed_guidance_config
+        self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
+    def prepare_models(self, denoiser: torch.nn.Module) -> None:
+        self._count_prepared += 1
+        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+            for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
+                _apply_layer_skip_hook(denoiser, config, name=name)
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
+    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
+            # Remove the hooks after inference
+            for hook_name in self._skip_layer_hook_names:
+                registry.remove_hook(hook_name, recurse=True)
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
+    def prepare_inputs(
+        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+    ) -> List["BlockState"]:
+        if input_fields is None:
+            input_fields = self._input_fields
+
+        if self.num_conditions == 1:
+            tuple_indices = [0]
+            input_predictions = ["pred_cond"]
+        elif self.num_conditions == 2:
+            tuple_indices = [0, 1]
+            input_predictions = (
+                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
+            )
+        else:
+            tuple_indices = [0, 1, 0]
+            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+        data_batches = []
+        for i in range(self.num_conditions):
+            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
+            data_batches.append(data_batch)
+        return data_batches
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
+    def forward(
+        self,
+        pred_cond: torch.Tensor,
+        pred_uncond: Optional[torch.Tensor] = None,
+        pred_cond_skip: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        pred = None
+
+        if not self._is_cfg_enabled() and not self._is_slg_enabled():
+            pred = pred_cond
+        elif not self._is_cfg_enabled():
+            shift = pred_cond - pred_cond_skip
+            pred = pred_cond if self.use_original_formulation else pred_cond_skip
+            pred = pred + self.skip_layer_guidance_scale * shift
+        elif not self._is_slg_enabled():
+            shift = pred_cond - pred_uncond
+            pred = pred_cond if self.use_original_formulation else pred_uncond
+            pred = pred + self.guidance_scale * shift
+        else:
+            shift = pred_cond - pred_uncond
+            shift_skip = pred_cond - pred_cond_skip
+            pred = pred_cond if self.use_original_formulation else pred_uncond
+            pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
+
+        if self.guidance_rescale > 0.0:
+            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+        return pred, {}
+
+    @property
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
+    def is_conditional(self) -> bool:
+        return self._count_prepared == 1 or self._count_prepared == 3
+
+    @property
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
+    def num_conditions(self) -> int:
+        num_conditions = 1
+        if self._is_cfg_enabled():
+            num_conditions += 1
+        if self._is_slg_enabled():
+            num_conditions += 1
+        return num_conditions
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
+    def _is_cfg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_close = False
+        if self.use_original_formulation:
+            is_close = math.isclose(self.guidance_scale, 0.0)
+        else:
+            is_close = math.isclose(self.guidance_scale, 1.0)
+
+        return is_within_range and not is_close
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
+    def _is_slg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
+            skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
+            is_within_range = skip_start_step < self._step < skip_stop_step
+
+        is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
+
+        return is_within_range and not is_zero
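
Note: the fully-enabled branch of the new forward() combines classifier-free guidance and the perturbed-attention shift as pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + skip_layer_guidance_scale * (pred_cond - pred_cond_skip). A minimal standalone sketch of that arithmetic with dummy tensors (shapes and scale values are illustrative, not defaults from this commit):

import torch

guidance_scale = 7.5             # CFG scale; illustrative value
skip_layer_guidance_scale = 3.0  # PAG scale; illustrative value

pred_cond = torch.randn(1, 4, 64, 64)       # conditional prediction
pred_uncond = torch.randn(1, 4, 64, 64)     # unconditional prediction
pred_cond_skip = torch.randn(1, 4, 64, 64)  # conditional prediction with attention scores skipped

# Mirrors the final `else` branch of forward() with use_original_formulation=False
shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
pred = pred_uncond + guidance_scale * shift + skip_layer_guidance_scale * shift_skip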

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 4 additions & 0 deletions
@@ -20,13 +20,17 @@
 from ..configuration_utils import register_to_config
 from ..hooks import HookRegistry, LayerSkipConfig
 from ..hooks.layer_skip import _apply_layer_skip_hook
+from ..utils import get_logger
 from .guider_utils import BaseGuidance, rescale_noise_cfg
 
 
 if TYPE_CHECKING:
     from ..modular_pipelines.modular_pipeline import BlockState
 
 
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+
 class SkipLayerGuidance(BaseGuidance):
     """
     Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
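
The only change here wires a module-level logger into skip_layer_guidance.py, matching the pattern already used in perturbed_attention_guidance.py above. A minimal sketch of that pattern (the logging call is a hypothetical call site, not part of this commit):

from diffusers.utils import get_logger

logger = get_logger(__name__)  # pylint: disable=invalid-name
logger.warning("example diagnostic message")  # hypothetical usage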

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -335,7 +335,7 @@ def init_pipeline(
     pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
     components_manager: Optional[ComponentsManager] = None,
     collection: Optional[str] = None,
-):
+) -> "ModularPipeline":
     """
     create a ModularPipeline, optionally accept modular_repo to load from hub.
     """
