Skip to content

Commit 357f4f0

Browse files
committed
update
1 parent 53b6b9f commit 357f4f0

File tree

3 files changed

+23
-20
lines changed

3 files changed

+23
-20
lines changed

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ class AdaptiveProjectedGuidance(GuidanceMixin):
3939
Flawed](https://huggingface.co/papers/2305.08891).
4040
use_original_formulation (`bool`, defaults to `False`):
4141
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
42-
we use the diffusers-native implementation that has been in the codebase for a long time.
42+
we use the diffusers-native implementation that has been in the codebase for a long time. See
43+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
4344
"""
4445

4546
_input_predictions = ["pred_cond", "pred_uncond"]

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ class ClassifierFreeGuidance(GuidanceMixin):
3131
The original paper proposes scaling and shifting the conditional distribution based on the difference between
3232
conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
3333
34-
Diffusers implemented the scaling and shifting on the unconditional prediction instead, which is equivalent to what
35-
the original paper proposed in theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
34+
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
35+
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
36+
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
3637
3738
The intuition behind the original formulation can be thought of as moving the conditional distribution estimates
3839
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
@@ -53,7 +54,8 @@ class ClassifierFreeGuidance(GuidanceMixin):
5354
Flawed](https://huggingface.co/papers/2305.08891).
5455
use_original_formulation (`bool`, defaults to `False`):
5556
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
56-
we use the diffusers-native implementation that has been in the codebase for a long time.
57+
we use the diffusers-native implementation that has been in the codebase for a long time. See
58+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
5759
"""
5860

5961
_input_predictions = ["pred_cond", "pred_uncond"]

src/diffusers/guiders/perturbed_attention_guidance.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -109,26 +109,26 @@ def prepare_outputs(self, pred: torch.Tensor) -> None:
109109
key = "pred_perturbed"
110110
self._preds[key] = pred
111111

112-
# Prepare denoiser for perturbed attention prediction if needed
113-
if not self._is_pag_enabled():
114-
return
115-
should_register_pag = (self._is_cfg_enabled() and self._num_outputs_prepared == 2) or (
116-
not self._is_cfg_enabled() and self._num_outputs_prepared == 1
117-
)
118-
if should_register_pag:
119-
self._is_pag_batch = True
120-
self._original_processors = _replace_attention_processors(
121-
self._denoiser,
122-
self.pag_applied_layers,
123-
skip_context_attention=self.skip_context_attention,
124-
metadata_name="perturbed_attention_guidance_processor_cls",
125-
)
126-
elif self._is_pag_batch:
127-
# Restore the original attention processors
112+
# Restore the original attention processors if previously replaced
113+
if self._is_pag_batch:
128114
_replace_attention_processors(self._denoiser, processors=self._original_processors)
129115
self._is_pag_batch = False
130116
self._original_processors = None
131117

118+
# Prepare denoiser for perturbed attention prediction if needed
119+
if self._is_pag_enabled():
120+
should_register_pag = (self._is_cfg_enabled() and self._num_outputs_prepared == 2) or (
121+
not self._is_cfg_enabled() and self._num_outputs_prepared == 1
122+
)
123+
if should_register_pag:
124+
self._is_pag_batch = True
125+
self._original_processors = _replace_attention_processors(
126+
self._denoiser,
127+
self.pag_applied_layers,
128+
skip_context_attention=self.skip_context_attention,
129+
metadata_name="perturbed_attention_guidance_processor_cls",
130+
)
131+
132132
def cleanup_models(self, denoiser: torch.nn.Module):
133133
self._denoiser = None
134134

0 commit comments

Comments
 (0)