Commit 2238f55

cfg zero*

1 parent 6255302 commit 2238f55

File tree

5 files changed: +157 −0 lines

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -134,6 +134,7 @@
         [
             "AdaptiveProjectedGuidance",
             "ClassifierFreeGuidance",
+            "ClassifierFreeZeroStarGuidance",
             "SkipLayerGuidance",
         ]
     )
@@ -724,6 +725,7 @@
     from .guiders import (
         AdaptiveProjectedGuidance,
         ClassifierFreeGuidance,
+        ClassifierFreeZeroStarGuidance,
         SkipLayerGuidance,
     )
     from .hooks import (

src/diffusers/guiders/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@
 if is_torch_available():
     from .adaptive_projected_guidance import AdaptiveProjectedGuidance
     from .classifier_free_guidance import ClassifierFreeGuidance
+    from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
     from .skip_layer_guidance import SkipLayerGuidance

 GuiderType = Union[ClassifierFreeGuidance, SkipLayerGuidance]

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 5 additions & 0 deletions

@@ -23,20 +23,25 @@
 class ClassifierFreeGuidance(BaseGuidance):
     """
     Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
+
     CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
     jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
     inference. This allows the model to trade off between generation quality and sample diversity.
     The original paper proposes scaling and shifting the conditional distribution based on the difference between
     conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
+
     Diffusers implements the scaling and shifting on the unconditional prediction instead, based on the [Imagen
     paper](https://huggingface.co/papers/2205.11487), which is theoretically equivalent to what the original paper
     proposed. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
+
     The intuition behind the original formulation is that it moves the conditional distribution estimates further
     away from the unconditional distribution estimates, while the diffusers-native implementation moves the
     unconditional distribution estimates towards the conditional distribution estimates, suppressing the
     unconditional predictions (usually negative features like "bad quality, bad anatomy, watermarks", etc.).
+
     The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
     paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
+
     Args:
         guidance_scale (`float`, defaults to `7.5`):
             The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
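
For concreteness, the two formulations described in the docstring above can be sketched as below. `cfg_combine` is a hypothetical helper written for illustration, not a diffusers API; it only restates the two bracketed equations.

import torch

def cfg_combine(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    scale: float,
    use_original_formulation: bool = False,
) -> torch.Tensor:
    shift = pred_cond - pred_uncond
    if use_original_formulation:
        # Original paper: x_pred = x_cond + scale * (x_cond - x_uncond)
        return pred_cond + scale * shift
    # Diffusers-native: x_pred = x_uncond + scale * (x_cond - x_uncond)
    # Equivalent to the original formulation with its scale set to scale - 1.
    return pred_uncond + scale * shift

The two branches differ only in the baseline term, which is why they are equivalent in theory: the diffusers-native form with scale w matches the original form with scale w - 1.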
src/diffusers/guiders/classifier_free_zero_star_guidance.py

Lines changed: 143 additions & 0 deletions

@@ -0,0 +1,143 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs
+
+
+class ClassifierFreeZeroStarGuidance(BaseGuidance):
+    """
+    Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
+
+    This is an implementation of the Classifier-Free Zero* guidance technique, a variant of classifier-free
+    guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
+    process, and introduces an optimal rescaling factor for the noise predictions, both of which can help improve
+    the quality of generated images.
+
+    The authors of the paper suggest applying zero initialization for the first 4% of the inference steps.
+
+    Args:
+        guidance_scale (`float`, defaults to `7.5`):
+            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
+            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
+            saturation and deterioration of image quality.
+        zero_init_steps (`int`, defaults to `1`):
+            The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
+        guidance_rescale (`float`, defaults to `0.0`):
+            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891).
+        use_original_formulation (`bool`, defaults to `False`):
+            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By
+            default, we use the diffusers-native implementation that has been in the codebase for a long time. See
+            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+        start (`float`, defaults to `0.0`):
+            The fraction of the total number of denoising steps after which guidance starts.
+        stop (`float`, defaults to `1.0`):
+            The fraction of the total number of denoising steps after which guidance stops.
+    """
+
+    _input_predictions = ["pred_cond", "pred_uncond"]
+
+    def __init__(
+        self,
+        guidance_scale: float = 7.5,
+        zero_init_steps: int = 1,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__(start, stop)
+
+        self.guidance_scale = guidance_scale
+        self.zero_init_steps = zero_init_steps
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+
+    def prepare_inputs(
+        self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]
+    ) -> Tuple[List[torch.Tensor], ...]:
+        return _default_prepare_inputs(denoiser, self.num_conditions, *args)
+
+    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
+        self._num_outputs_prepared += 1
+        if self._num_outputs_prepared > self.num_conditions:
+            raise ValueError(
+                f"Expected {self.num_conditions} outputs, but `prepare_outputs` was called more times."
+            )
+        key = self._input_predictions[self._num_outputs_prepared - 1]
+        self._preds[key] = pred
+
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+        pred = None
+
+        if self._step < self.zero_init_steps:
+            pred = torch.zeros_like(pred_cond)
+        elif not self._is_cfg_enabled():
+            pred = pred_cond
+        else:
+            pred_cond_flat = pred_cond.flatten(1)
+            pred_uncond_flat = pred_uncond.flatten(1)
+            alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
+            alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
+            pred_uncond = pred_uncond * alpha
+            shift = pred_cond - pred_uncond
+            pred = pred_cond if self.use_original_formulation else pred_uncond
+            pred = pred + self.guidance_scale * shift
+
+        if self.guidance_rescale > 0.0:
+            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+        return pred
+
+    @property
+    def is_conditional(self) -> bool:
+        return self._num_outputs_prepared == 0
+
+    @property
+    def num_conditions(self) -> int:
+        num_conditions = 1
+        if self._is_cfg_enabled():
+            num_conditions += 1
+        return num_conditions
+
+    def _is_cfg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_close = False
+        if self.use_original_formulation:
+            is_close = math.isclose(self.guidance_scale, 0.0)
+        else:
+            is_close = math.isclose(self.guidance_scale, 1.0)
+
+        return is_within_range and not is_close
+
+
+def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
+    # st_star = v_cond^T * v_uncond / ||v_uncond||^2, computed per sample in float32
+    cond_dtype = cond.dtype
+    cond = cond.float()
+    uncond = uncond.float()
+    dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
+    squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
+    scale = dot_product / squared_norm
+    return scale.to(dtype=cond_dtype)
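
As a standalone illustration of the optimal rescaling, here is a minimal sketch that applies `cfg_zero_star_scale` (defined above) to random noise predictions and combines them the same way `forward` does. The shapes and values are illustrative only, and the sketch assumes the helper above is in scope.

import torch

# Illustrative shapes: a batch of 2 latent noise predictions (4x64x64 each).
pred_cond = torch.randn(2, 4, 64, 64)
pred_uncond = torch.randn(2, 4, 64, 64)
guidance_scale = 7.5

# Per-sample projection scale: st_star = <v_cond, v_uncond> / ||v_uncond||^2
alpha = cfg_zero_star_scale(pred_cond.flatten(1), pred_uncond.flatten(1))
alpha = alpha.view(-1, 1, 1, 1)

# Rescale the unconditional branch, then apply the usual CFG combination.
pred_uncond_scaled = pred_uncond * alpha
pred = pred_uncond_scaled + guidance_scale * (pred_cond - pred_uncond_scaled)

# For the first `zero_init_steps` steps, the guider instead returns zeros.
pred_zero_init = torch.zeros_like(pred_cond)

Following the paper's 4% suggestion, `zero_init_steps` would be set to roughly `int(0.04 * num_inference_steps)` rather than the default of 1.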

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 6 additions & 0 deletions

@@ -26,20 +26,26 @@ class SkipLayerGuidance(BaseGuidance):
     """
     Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
     Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664
+
     SLG was introduced by StabilityAI for improving structure and anatomy coherence in generated images. It works by
     skipping the forward pass of specified transformer blocks during the denoising process on an additional
     conditional batch of data, apart from the conditional and unconditional batches already used in CFG
     ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
     based on the difference between the conditional predictions computed without and with layer skipping.
+
     The intuition behind SLG is that it moves the CFG predicted distribution estimates further away from worse
     versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
     version of the model for the conditional prediction).
+
     STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
     generation quality in video diffusion models.
+
     Additional reading:
     - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
+
     The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
     defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
+
     Args:
         guidance_scale (`float`, defaults to `7.5`):
             The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
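
The SLG forward pass itself is not part of this diff, but the combination the docstring describes might be sketched as follows. The function name, signature, and default scales here are assumptions made for illustration, not the actual `SkipLayerGuidance` implementation.

import torch

def slg_combine_sketch(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    pred_cond_skip: torch.Tensor,
    guidance_scale: float = 7.5,
    skip_layer_guidance_scale: float = 2.8,  # assumed SD3.5 Medium-style value
) -> torch.Tensor:
    # Standard CFG on the full-model predictions.
    pred_cfg = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
    # Shift the CFG prediction away from the "worse" (layer-skipped)
    # conditional prediction, toward the full conditional prediction.
    return pred_cfg + skip_layer_guidance_scale * (pred_cond - pred_cond_skip)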
