Skip to content

Commit 2dc673a

Browse files
committed
tangential cfg
1 parent b9bcd46 commit 2dc673a

File tree

6 files changed

+148
-4
lines changed

6 files changed

+148
-4
lines changed

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@
138138
"ClassifierFreeZeroStarGuidance",
139139
"SkipLayerGuidance",
140140
"SmoothedEnergyGuidance",
141+
"TangentialClassifierFreeGuidance",
141142
]
142143
)
143144
_import_structure["hooks"].extend(
@@ -732,6 +733,7 @@
732733
ClassifierFreeZeroStarGuidance,
733734
SkipLayerGuidance,
734735
SmoothedEnergyGuidance,
736+
TangentialClassifierFreeGuidance,
735737
)
736738
from .hooks import (
737739
FasterCacheConfig,

src/diffusers/guiders/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,6 @@
2424
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
2525
from .skip_layer_guidance import SkipLayerGuidance
2626
from .smoothed_energy_guidance import SmoothedEnergyGuidance
27+
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
2728

28-
GuiderType = Union[AdaptiveProjectedGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance]
29+
GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance]

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,20 +155,25 @@ def normalized_guidance(
155155
):
156156
diff = pred_cond - pred_uncond
157157
dim = [-i for i in range(1, len(diff.shape))]
158+
158159
if momentum_buffer is not None:
159160
momentum_buffer.update(diff)
160161
diff = momentum_buffer.running_average
162+
161163
if norm_threshold > 0:
162164
ones = torch.ones_like(diff)
163165
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
164166
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
165167
diff = diff * scale_factor
168+
166169
v0, v1 = diff.double(), pred_cond.double()
167170
v1 = torch.nn.functional.normalize(v1, dim=dim)
168171
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
169172
v0_orthogonal = v0 - v0_parallel
170173
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
171174
normalized_update = diff_orthogonal + eta * diff_parallel
175+
172176
pred = pred_cond if use_original_formulation else pred_uncond
173-
pred = pred + (guidance_scale - 1) * normalized_update
177+
pred = pred + guidance_scale * normalized_update
178+
174179
return pred

src/diffusers/guiders/smoothed_energy_guidance.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ class SmoothedEnergyGuidance(BaseGuidance):
2727
Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760
2828
2929
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified
30-
in the future without warning or guarantee of reproducibility.
30+
in the future without warning or guarantee of reproducibility. This implementation assumes:
31+
- Generated images are square (height == width)
32+
- The model does not combine different modalities together (e.g., text and image latent streams are
33+
not combined together such as Flux)
3134
3235
Args:
3336
guidance_scale (`float`, defaults to `7.5`):
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
from typing import Optional, Union, Tuple, List
17+
18+
import torch
19+
20+
from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs
21+
22+
23+
class TangentialClassifierFreeGuidance(BaseGuidance):
    """
    Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137

    When guidance is active, the conditional and unconditional predictions are combined by the module-level
    `normalized_guidance` function, which first replaces the unconditional prediction with its projection onto the
    leading right-singular vector of the stacked (conditional, unconditional) predictions. When guidance is not
    active, the conditional prediction is returned unchanged.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. Rescaling is only applied when this value is
            greater than `0.0`. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By
            default, the diffusers-native formulation is used. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    # Keys under which `prepare_outputs` stores successive predictions, in call order.
    _input_predictions = ["pred_cond", "pred_uncond"]

    def __init__(
        self,
        guidance_scale: float = 7.5,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation
        # Not referenced by any method of this class; kept as-is for attribute compatibility.
        self.momentum_buffer = None

    def prepare_inputs(
        self,
        denoiser: torch.nn.Module,
        *args: Union[Tuple[torch.Tensor], List[torch.Tensor]],
    ) -> Tuple[List[torch.Tensor], ...]:
        """Delegate input preparation to `_default_prepare_inputs` for the current number of conditions."""
        return _default_prepare_inputs(denoiser, self.num_conditions, *args)

    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
        """
        Record one denoiser prediction under the next key of `_input_predictions`.

        Raises:
            ValueError: If called more times than `num_conditions` allows. The call counter has already been
                incremented when the error is raised.
        """
        self._num_outputs_prepared += 1
        prepared = self._num_outputs_prepared
        if prepared > self.num_conditions:
            raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
        # The counter is 1-based after the increment, the key list is 0-based.
        self._preds[self._input_predictions[prepared - 1]] = pred

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Combine the predictions into the final guided prediction.

        Returns `pred_cond` untouched when TCFG is inactive, otherwise the output of `normalized_guidance`. In
        both cases the result is passed through `rescale_noise_cfg` when `guidance_rescale` is positive.
        """
        if self._is_tcfg_enabled():
            guided = normalized_guidance(
                pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation
            )
        else:
            guided = pred_cond

        if self.guidance_rescale > 0.0:
            guided = rescale_noise_cfg(guided, pred_cond, self.guidance_rescale)

        return guided

    @property
    def is_conditional(self) -> bool:
        """`True` while no prediction has been recorded yet, i.e. the next one is the conditional prediction."""
        return self._num_outputs_prepared == 0

    @property
    def num_conditions(self) -> int:
        """Number of predictions required: 2 when TCFG is active, otherwise 1."""
        return 2 if self._is_tcfg_enabled() else 1

    def _is_tcfg_enabled(self) -> bool:
        """
        Tell whether guidance should be applied at the current step.

        Guidance is active when the guider is enabled, the current step lies in the half-open window
        `[int(start * num_steps), int(stop * num_steps))` (the window check is skipped when the number of
        inference steps is unknown), and `guidance_scale` is not close to the value at which the guidance
        term vanishes (`0.0` for the original formulation, `1.0` otherwise).
        """
        if not self._enabled:
            return False

        total_steps = self._num_inference_steps
        if total_steps is None:
            in_window = True
        else:
            first_step = int(self._start * total_steps)
            last_step = int(self._stop * total_steps)
            in_window = first_step <= self._step < last_step

        neutral_scale = 0.0 if self.use_original_formulation else 1.0
        scale_is_neutral = math.isclose(self.guidance_scale, neutral_scale)

        return in_window and not scale_is_neutral
114+
115+
116+
def normalized_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    use_original_formulation: bool = False,
) -> torch.Tensor:
    """
    Apply tangential classifier-free guidance to a pair of predictions.

    The unconditional prediction is replaced by its projection onto the leading right-singular vector of the
    per-sample matrix formed by stacking the conditional and unconditional predictions; the component along the
    second right-singular vector is dropped. The guidance update is then computed from that modified
    unconditional prediction.

    Args:
        pred_cond (`torch.Tensor`):
            Conditional prediction. Dimension 0 is treated as the batch dimension; all remaining dimensions are
            flattened for the decomposition.
        pred_uncond (`torch.Tensor`):
            Unconditional prediction, stacked with `pred_cond` and therefore expected to have the same shape.
        guidance_scale (`float`):
            Multiplier applied to the difference between `pred_cond` and the modified unconditional prediction.
        use_original_formulation (`bool`, defaults to `False`):
            If `True`, the scaled difference is added to `pred_cond`; otherwise it is added to the modified
            unconditional prediction.

    Returns:
        `torch.Tensor`: The guided prediction, with the same shape as `pred_uncond`.
    """
    original_dtype = pred_cond.dtype
    batch_size = pred_uncond.size(0)

    # Per sample, build a (2, num_features) matrix whose rows are the two predictions. The SVD runs in float32.
    stacked = torch.stack([pred_cond, pred_uncond], dim=1).float().flatten(2)
    _, _, Vh = torch.linalg.svd(stacked, full_matrices=False)

    # Same basis with the second right-singular vector zeroed out, so only the leading direction survives.
    Vh_leading_only = Vh.clone()
    Vh_leading_only[:, 1] = 0

    # Express the unconditional prediction in the singular basis, then map back through the truncated basis.
    uncond_rows = pred_uncond.reshape(batch_size, 1, -1).float()
    coefficients = torch.matmul(uncond_rows, Vh.transpose(-2, -1))
    projected = torch.matmul(coefficients, Vh_leading_only)
    pred_uncond = projected.reshape(pred_uncond.shape).to(original_dtype)

    base = pred_cond if use_original_formulation else pred_uncond
    return base + guidance_scale * (pred_cond - pred_uncond)

src/diffusers/hooks/layer_skip.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bo
9292

9393
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
9494
if self.skip_attention_scores:
95-
if math.isclose(self.dropout, 1.0):
95+
if not math.isclose(self.dropout, 1.0):
9696
raise ValueError(
9797
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
9898
)

0 commit comments

Comments
 (0)