Skip to content

Commit 77d8a28

Browse files
committed
cfg plus plus
1 parent 2dc673a commit 77d8a28

File tree

7 files changed

+163
-5
lines changed

7 files changed

+163
-5
lines changed

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@
134134
[
135135
"AdaptiveProjectedGuidance",
136136
"AutoGuidance",
137+
"CFGPlusPlusGuidance",
137138
"ClassifierFreeGuidance",
138139
"ClassifierFreeZeroStarGuidance",
139140
"SkipLayerGuidance",
@@ -729,6 +730,7 @@
729730
from .guiders import (
730731
AdaptiveProjectedGuidance,
731732
AutoGuidance,
733+
CFGPlusPlusGuidance,
732734
ClassifierFreeGuidance,
733735
ClassifierFreeZeroStarGuidance,
734736
SkipLayerGuidance,

src/diffusers/guiders/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
if is_torch_available():
2121
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
2222
from .auto_guidance import AutoGuidance
23+
from .classifier_free_guidance_plus_plus import CFGPlusPlusGuidance
2324
from .classifier_free_guidance import ClassifierFreeGuidance
2425
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
2526
from .skip_layer_guidance import SkipLayerGuidance
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
from typing import Optional, Union, Tuple, List
17+
18+
import torch
19+
20+
from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs
21+
22+
23+
class CFGPlusPlusGuidance(BaseGuidance):
    """
    CFG++: https://huggingface.co/papers/2406.08070

    The guider works in two stages. `forward` combines the conditional and unconditional predictions into the
    guided prediction that is handed to the scheduler. `post_scheduler_step` then corrects the scheduler output so
    that the re-noising direction uses the unconditional prediction instead of the guided one.

    Args:
        guidance_scale (`float`, defaults to `0.7`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    # Order in which successive `prepare_outputs` calls are mapped to prediction names.
    _input_predictions = ["pred_cond", "pred_uncond"]

    def __init__(
        self,
        guidance_scale: float = 0.7,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(
        self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]
    ) -> Tuple[List[torch.Tensor], ...]:
        """Delegate to `_default_prepare_inputs` using the current number of conditions."""
        return _default_prepare_inputs(denoiser, self.num_conditions, *args)

    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
        """
        Record one denoiser prediction for the current step.

        The n-th call stores `pred` under the n-th name of `_input_predictions`.

        Raises:
            ValueError: If called more than `num_conditions` times since the last state reset.
        """
        self._num_outputs_prepared += 1
        if self._num_outputs_prepared > self.num_conditions:
            raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
        key = self._input_predictions[self._num_outputs_prepared - 1]
        self._preds[key] = pred

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Return the prediction that should be passed to the scheduler.

        When CFG++ is inactive the conditional prediction is used (still subject to `guidance_rescale`); otherwise
        the guided combination of `pred_cond` and `pred_uncond` is returned.
        """
        if self._is_cfgpp_enabled():
            return self._guided_prediction(pred_cond, pred_uncond)

        pred = pred_cond
        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
        return pred

    def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor:
        """
        Apply the CFG++ correction to the scheduler output `pred`.

        For an Euler step the scheduler produces `sample + model_output * (sigma_next - sigma)`, while CFG++ wants
        `sample - model_output * sigma + pred_uncond * sigma_next`. The two differ by
        `(pred_uncond - model_output) * sigma_next`, where `model_output` is exactly what `forward` returned. For the
        default formulation without rescaling this equals `guidance_scale * (pred_uncond - pred_cond) * sigma_next`.
        Recomputing the guided prediction keeps the correction valid when `use_original_formulation` or
        `guidance_rescale` change what `forward` returns.

        Raises:
            ValueError: If CFG++ is active but no `sigma_next` has been provided through `set_state`.
        """
        if not self._is_cfgpp_enabled():
            return pred

        if self._sigma_next is None:
            raise ValueError(
                "CFGPlusPlusGuidance requires `sigma_next` to be provided via `set_state` before `post_scheduler_step`."
            )

        # TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later!
        pred_cond = self._preds["pred_cond"]
        pred_uncond = self._preds["pred_uncond"]
        guided = self._guided_prediction(pred_cond, pred_uncond)
        return pred + (pred_uncond - guided) * self._sigma_next

    @property
    def is_conditional(self) -> bool:
        """`True` while no prediction has been recorded yet for the current step."""
        return self._num_outputs_prepared == 0

    @property
    def num_conditions(self) -> int:
        """Number of denoiser predictions needed per step: 2 while CFG++ is active, otherwise 1."""
        return 2 if self._is_cfgpp_enabled() else 1

    def _guided_prediction(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor]) -> torch.Tensor:
        """Combine both predictions into the guided model output (including the optional rescale)."""
        if pred_uncond is None:
            raise ValueError("`pred_uncond` must be provided while CFG++ guidance is enabled.")

        shift = pred_cond - pred_uncond
        base = pred_cond if self.use_original_formulation else pred_uncond
        guided = base + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            guided = rescale_noise_cfg(guided, pred_cond, self.guidance_rescale)
        return guided

    def _is_cfgpp_enabled(self) -> bool:
        """Return whether guidance is enabled and the current step lies in the `[start, stop)` step window."""
        if not self._enabled:
            return False

        # Without a known step count the range cannot be evaluated, so guidance is treated as active.
        if self._num_inference_steps is None:
            return True

        skip_start_step = int(self._start * self._num_inference_steps)
        skip_stop_step = int(self._stop * self._num_inference_steps)
        return skip_start_step <= self._step < skip_stop_step

src/diffusers/guiders/guider_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ def __init__(self, start: float = 0.0, stop: float = 1.0):
3737
self._step: int = None
3838
self._num_inference_steps: int = None
3939
self._timestep: torch.LongTensor = None
40+
self._sigma: torch.Tensor = None
41+
self._sigma_next: torch.Tensor = None
4042
self._preds: Dict[str, torch.Tensor] = {}
4143
self._num_outputs_prepared: int = 0
4244
self._enabled = True
@@ -61,10 +63,12 @@ def _force_disable(self):
6163
def _force_enable(self):
6264
self._enabled = True
6365

64-
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
66+
def set_state(
    self,
    step: int,
    num_inference_steps: int,
    timestep: torch.LongTensor,
    sigma: torch.Tensor = None,
    sigma_next: torch.Tensor = None,
) -> None:
    """
    Store the per-step denoising state and reset the per-step prediction bookkeeping.

    Args:
        step: Index of the current denoising step.
        num_inference_steps: Total number of denoising steps.
        timestep: Timestep value for the current step.
        sigma: Optional noise level for the current step. Defaults to `None` so that callers that only pass
            `step`, `num_inference_steps` and `timestep` keep working.
        sigma_next: Optional noise level for the following step. Defaults to `None` for the same reason.
    """
    self._step = step
    self._num_inference_steps = num_inference_steps
    self._timestep = timestep
    self._sigma = sigma
    self._sigma_next = sigma_next
    # Predictions recorded for a previous step must not leak into the new one.
    self._preds = {}
    self._num_outputs_prepared = 0
7074

@@ -91,6 +95,9 @@ def __call__(self, **kwargs) -> Any:
9195
def forward(self, *args, **kwargs) -> Any:
9296
raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
9397

98+
def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor:
    """
    Hook invoked on the scheduler output after each step.

    The base implementation returns `pred` unchanged; subclasses may override it to adjust the result.
    """
    return pred
100+
94101
@property
95102
def is_conditional(self) -> bool:
96103
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")

src/diffusers/guiders/tangential_classifier_free_guidance.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ def __init__(
5858
self.guidance_scale = guidance_scale
5959
self.guidance_rescale = guidance_rescale
6060
self.use_original_formulation = use_original_formulation
61-
self.momentum_buffer = None
6261

6362
def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
6463
return _default_prepare_inputs(denoiser, self.num_conditions, *args)

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2241,7 +2241,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
22412241

22422242
with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
22432243
for i, t in enumerate(data.timesteps):
2244-
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
2244+
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
22452245

22462246
(
22472247
latents,
@@ -2301,6 +2301,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
23012301
# Perform scheduler step using the predicted output
23022302
data.latents_dtype = data.latents.dtype
23032303
data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
2304+
data.latents = pipeline.guider.post_scheduler_step(data.latents)
23042305

23052306
if data.latents.dtype != data.latents_dtype:
23062307
if torch.backends.mps.is_available():
@@ -2637,7 +2638,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
26372638
# (5) Denoise loop
26382639
with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
26392640
for i, t in enumerate(data.timesteps):
2640-
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
2641+
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
26412642

26422643
(
26432644
latents,
@@ -2730,6 +2731,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
27302731
# Perform scheduler step using the predicted output
27312732
data.latents_dtype = data.latents.dtype
27322733
data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
2734+
data.latents = pipeline.guider.post_scheduler_step(data.latents)
27332735

27342736
if data.latents.dtype != data.latents_dtype:
27352737
if torch.backends.mps.is_available():
@@ -3053,7 +3055,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
30533055

30543056
with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
30553057
for i, t in enumerate(data.timesteps):
3056-
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
3058+
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
30573059

30583060
(
30593061
latents,
@@ -3148,6 +3150,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
31483150
# Perform scheduler step using the predicted output
31493151
data.latents_dtype = data.latents.dtype
31503152
data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
3153+
data.latents = pipeline.guider.post_scheduler_step(data.latents)
31513154

31523155
if data.latents.dtype != data.latents_dtype:
31533156
if torch.backends.mps.is_available():

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,35 @@ def step(
669669

670670
prev_sample = sample + derivative * dt
671671

672+
# denoised = sample - model_output * sigmas[i]
673+
# d = (sample - denoised) / sigmas[i]
674+
# new_sample = denoised + d * sigmas[i + 1]
675+
676+
# new_sample = denoised + (sample - denoised) * sigmas[i + 1] / sigmas[i]
677+
# new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1]
678+
# new_sample = sample + model_output * (sigmas[i + 1] - sigmas[i])
679+
# new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1] --- (1)
680+
681+
# CFG++ =====
682+
# denoised = sample - model_output * sigmas[i]
683+
# uncond_denoised = sample - model_output_uncond * sigmas[i]
684+
# d = (sample - uncond_denoised) / sigmas[i]
685+
# new_sample = denoised + d * sigmas[i + 1]
686+
687+
# new_sample = denoised + (sample - uncond_denoised) * sigmas[i + 1] / sigmas[i]
688+
# new_sample = sample - model_output * sigmas[i] + model_output_uncond * sigmas[i + 1] --- (2)
689+
690+
# To go from (1) to (2):
691+
# new_sample_2 = new_sample_1 - model_output * sigmas[i + 1] + model_output_uncond * sigmas[i + 1]
692+
# new_sample_2 = new_sample_1 + (model_output_uncond - model_output) * sigmas[i + 1]
693+
# new_sample_2 = new_sample_1 + diff * sigmas[i + 1]
694+
695+
# diff = model_output_uncond - model_output
696+
# diff = model_output_uncond - (model_output_uncond + g * (model_output_cond - model_output_uncond))
697+
# diff = model_output_uncond - (g * model_output_cond + (1 - g) * model_output_uncond)
698+
# diff = model_output_uncond - g * model_output_cond + (g - 1) * model_output_uncond
699+
# diff = g * (model_output_uncond - model_output_cond)
700+
672701
# Cast sample back to model compatible dtype
673702
prev_sample = prev_sample.to(model_output.dtype)
674703

0 commit comments

Comments
 (0)