
Commit b81bd78

support controlnet union
1 parent 9da8a9d commit b81bd78

6 files changed: +312 -144 lines changed


src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -132,6 +132,7 @@
 else:
     _import_structure["guiders"].extend(
         [
+            "AdaptiveProjectedGuidance",
             "ClassifierFreeGuidance",
             "SkipLayerGuidance",
         ]
@@ -721,6 +722,7 @@
         from .utils.dummy_pt_objects import *  # noqa F403
     else:
         from .guiders import (
+            AdaptiveProjectedGuidance,
            ClassifierFreeGuidance,
            SkipLayerGuidance,
        )
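
With both registrations in place, the new guider becomes importable from the package root. A minimal sketch, assuming a diffusers build that includes this commit and an environment with torch available (the argument values below are illustrative, not recommendations):

    from diffusers import AdaptiveProjectedGuidance

    # Constructor arguments mirror the defaults added in this commit.
    guider = AdaptiveProjectedGuidance(
        guidance_scale=7.5,
        adaptive_projected_guidance_momentum=-0.5,  # None disables the momentum buffer
        adaptive_projected_guidance_rescale=15.0,
    )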

src/diffusers/guiders/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@


 if is_torch_available():
+    from .adaptive_projected_guidance import AdaptiveProjectedGuidance
     from .classifier_free_guidance import ClassifierFreeGuidance
     from .skip_layer_guidance import SkipLayerGuidance

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 174 additions & 0 deletions

@@ -0,0 +1,174 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Union, Tuple, List
+
+import torch
+
+from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs
+
+
+class AdaptiveProjectedGuidance(BaseGuidance):
+    """
+    Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
+
+    Args:
+        guidance_scale (`float`, defaults to `7.5`):
+            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+            deterioration of image quality.
+        adaptive_projected_guidance_momentum (`float`, defaults to `None`):
+            The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
+        adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
+            The norm threshold for the guidance update; the cond/uncond difference is rescaled so its norm does not exceed this value.
+        guidance_rescale (`float`, defaults to `0.0`):
+            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891).
+        use_original_formulation (`bool`, defaults to `False`):
+            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+            we use the diffusers-native implementation that has been in the codebase for a long time. See
+            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+        start (`float`, defaults to `0.0`):
+            The fraction of the total number of denoising steps after which guidance starts.
+        stop (`float`, defaults to `1.0`):
+            The fraction of the total number of denoising steps after which guidance stops.
+    """
+
+    _input_predictions = ["pred_cond", "pred_uncond"]
+
+    def __init__(
+        self,
+        guidance_scale: float = 7.5,
+        adaptive_projected_guidance_momentum: Optional[float] = None,
+        adaptive_projected_guidance_rescale: float = 15.0,
+        eta: float = 1.0,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__(start, stop)
+
+        self.guidance_scale = guidance_scale
+        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
+        self.eta = eta
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+        self.momentum_buffer = None
+
+    def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
+        if self._step == 0:
+            if self.adaptive_projected_guidance_momentum is not None:
+                self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+        return _default_prepare_inputs(denoiser, self.num_conditions, *args)
+
+    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
+        self._num_outputs_prepared += 1
+        if self._num_outputs_prepared > self.num_conditions:
+            raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
+        key = self._input_predictions[self._num_outputs_prepared - 1]
+        self._preds[key] = pred
+
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+        pred = None
+
+        if not self._is_cfg_enabled():
+            pred = pred_cond
+        else:
+            pred = normalized_guidance(
+                pred_cond,
+                pred_uncond,
+                self.guidance_scale,
+                self.momentum_buffer,
+                self.eta,
+                self.adaptive_projected_guidance_rescale,
+                self.use_original_formulation,
+            )
+
+        if self.guidance_rescale > 0.0:
+            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+        return pred
+
+    @property
+    def is_conditional(self) -> bool:
+        return self._num_outputs_prepared == 0
+
+    @property
+    def num_conditions(self) -> int:
+        num_conditions = 1
+        if self._is_cfg_enabled():
+            num_conditions += 1
+        return num_conditions
+
+    def _is_cfg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_close = False
+        if self.use_original_formulation:
+            is_close = math.isclose(self.guidance_scale, 0.0)
+        else:
+            is_close = math.isclose(self.guidance_scale, 1.0)
+
+        return is_within_range and not is_close
+
+
+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+
+def normalized_guidance(
+    pred_cond: torch.Tensor,
+    pred_uncond: torch.Tensor,
+    guidance_scale: float,
+    momentum_buffer: Optional[MomentumBuffer] = None,
+    eta: float = 1.0,
+    norm_threshold: float = 0.0,
+    use_original_formulation: bool = False,
+):
+    diff = pred_cond - pred_uncond
+    dim = [-i for i in range(1, len(diff.shape))]
+    if momentum_buffer is not None:
+        momentum_buffer.update(diff)
+        diff = momentum_buffer.running_average
+    if norm_threshold > 0:
+        ones = torch.ones_like(diff)
+        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
+        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+        diff = diff * scale_factor
+    v0, v1 = diff.double(), pred_cond.double()
+    v1 = torch.nn.functional.normalize(v1, dim=dim)
+    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
+    normalized_update = diff_orthogonal + eta * diff_parallel
+    pred = pred_cond if use_original_formulation else pred_uncond
+    pred = pred + (guidance_scale - 1) * normalized_update
+    return pred
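
The heart of the new guider is normalized_guidance: the conditional/unconditional difference is optionally smoothed by the momentum buffer and clamped to norm_threshold, then split into components parallel and orthogonal to the conditional prediction; the final update adds the orthogonal part plus eta times the parallel part, scaled by (guidance_scale - 1). A minimal sketch on dummy tensors, assuming the module path introduced by this commit (shapes and parameter values are illustrative):

    import torch

    from diffusers.guiders.adaptive_projected_guidance import MomentumBuffer, normalized_guidance

    # Stand-ins for a denoiser's conditional and unconditional noise predictions.
    pred_cond = torch.randn(1, 4, 64, 64)
    pred_uncond = torch.randn(1, 4, 64, 64)

    # Momentum smooths the cond/uncond difference across steps; the same buffer instance
    # would be reused at every denoising step. -0.5 is an illustrative value (default is None).
    buffer = MomentumBuffer(momentum=-0.5)

    pred = normalized_guidance(
        pred_cond,
        pred_uncond,
        guidance_scale=7.5,
        momentum_buffer=buffer,
        eta=1.0,              # weight of the component parallel to pred_cond
        norm_threshold=15.0,  # maps to adaptive_projected_guidance_rescale in the guider
        use_original_formulation=False,
    )
    print(pred.shape)  # torch.Size([1, 4, 64, 64])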

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 8 additions & 3 deletions
@@ -106,12 +106,17 @@ def num_conditions(self) -> int:
     def _is_cfg_enabled(self) -> bool:
         if not self._enabled:
             return False
-        skip_start_step = int(self._start * self._num_inference_steps)
-        skip_stop_step = int(self._stop * self._num_inference_steps)
-        is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
         is_close = False
         if self.use_original_formulation:
             is_close = math.isclose(self.guidance_scale, 0.0)
         else:
             is_close = math.isclose(self.guidance_scale, 1.0)
+
         return is_within_range and not is_close
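
The new guard matters because the range check can be reached before the pipeline has told the guider how many inference steps it will run. A standalone sketch of the pattern (plain Python, not the diffusers API), contrasting the old failure with the new behaviour:

    from typing import Optional

    def is_within_range(step: int, start: float, stop: float, num_inference_steps: Optional[int]) -> bool:
        # Old behaviour: int(start * None) raised a TypeError when the step count was unknown.
        # New behaviour: treat the step as within range until num_inference_steps is set.
        if num_inference_steps is None:
            return True
        skip_start_step = int(start * num_inference_steps)
        skip_stop_step = int(stop * num_inference_steps)
        return skip_start_step <= step < skip_stop_step

    print(is_within_range(step=0, start=0.0, stop=1.0, num_inference_steps=None))   # True
    print(is_within_range(step=40, start=0.0, stop=0.75, num_inference_steps=50))   # False

The same guard is applied to the skip-layer guidance checks below.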

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 16 additions & 6 deletions
@@ -206,21 +206,31 @@ def num_conditions(self) -> int:
     def _is_cfg_enabled(self) -> bool:
         if not self._enabled:
             return False
-        skip_start_step = int(self._start * self._num_inference_steps)
-        skip_stop_step = int(self._stop * self._num_inference_steps)
-        is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
         is_close = False
         if self.use_original_formulation:
             is_close = math.isclose(self.guidance_scale, 0.0)
         else:
             is_close = math.isclose(self.guidance_scale, 1.0)
+
         return is_within_range and not is_close

     def _is_slg_enabled(self) -> bool:
         if not self._enabled:
             return False
-        skip_start_step = int(self._start * self._num_inference_steps)
-        skip_stop_step = int(self._stop * self._num_inference_steps)
-        is_within_range = skip_start_step < self._step < skip_stop_step
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step < self._step < skip_stop_step
+
         is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
+
         return is_within_range and not is_zero
