Skip to content

Commit 0c4c1a8

Browse files
committed
cfg; slg; pag; sdxl without controlnet
1 parent d143851 commit 0c4c1a8

File tree

11 files changed

+1093
-73
lines changed

11 files changed

+1093
-73
lines changed

src/diffusers/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
_import_structure = {
3535
"configuration_utils": ["ConfigMixin"],
36+
"guiders": [],
3637
"hooks": [],
3738
"loaders": ["FromOriginalModelMixin"],
3839
"models": [],
@@ -129,12 +130,20 @@
129130
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
130131

131132
else:
133+
_import_structure["guiders"].extend(
134+
[
135+
"ClassifierFreeGuidance",
136+
"SkipLayerGuidance",
137+
]
138+
)
132139
_import_structure["hooks"].extend(
133140
[
134141
"FasterCacheConfig",
135142
"HookRegistry",
136143
"PyramidAttentionBroadcastConfig",
144+
"LayerSkipConfig",
137145
"apply_faster_cache",
146+
"apply_layer_skip",
138147
"apply_pyramid_attention_broadcast",
139148
]
140149
)
@@ -711,10 +720,16 @@
711720
except OptionalDependencyNotAvailable:
712721
from .utils.dummy_pt_objects import * # noqa F403
713722
else:
723+
from .guiders import (
724+
ClassifierFreeGuidance,
725+
SkipLayerGuidance,
726+
)
714727
from .hooks import (
715728
FasterCacheConfig,
716729
HookRegistry,
730+
LayerSkipConfig,
717731
PyramidAttentionBroadcastConfig,
732+
apply_layer_skip,
718733
apply_faster_cache,
719734
apply_pyramid_attention_broadcast,
720735
)

src/diffusers/guiders/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Union
16+
17+
from ..utils import is_torch_available
18+
19+
20+
if is_torch_available():
21+
from .classifier_free_guidance import ClassifierFreeGuidance
22+
from .skip_layer_guidance import SkipLayerGuidance
23+
24+
GuiderType = Union[ClassifierFreeGuidance, SkipLayerGuidance]
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
from typing import Optional, Union, Tuple, List
17+
18+
import torch
19+
20+
from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs
21+
22+
23+
class ClassifierFreeGuidance(BaseGuidance):
    """
    Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598

    CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
    jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
    inference. This allows the model to trade off between generation quality and sample diversity.

    The original paper proposes scaling and shifting the conditional distribution based on the difference between
    conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]

    Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
    paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
    theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]

    The intuition behind the original formulation can be thought of as moving the conditional distribution estimates
    further away from the unconditional distribution estimates, while the diffusers-native implementation can be
    thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
    the unconditional predictions (usually negative features like "bad quality, bad anatomy, watermarks", etc.)

    The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
    paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    # Names under which successive denoiser outputs are stored by `prepare_outputs`, in call order.
    _input_predictions = ["pred_cond", "pred_uncond"]

    def __init__(
        self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False, start: float = 0.0, stop: float = 1.0
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
        """Expand every argument into one entry per condition by delegating to `_default_prepare_inputs`."""
        return _default_prepare_inputs(denoiser, self.num_conditions, *args)

    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
        """
        Store `pred` under the next name from `_input_predictions`.

        Raises:
            ValueError: If called more than `num_conditions` times since the counter was last reset.
        """
        self._num_outputs_prepared += 1
        if self._num_outputs_prepared > self.num_conditions:
            raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
        key = self._input_predictions[self._num_outputs_prepared - 1]
        self._preds[key] = pred

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Combine the conditional and unconditional predictions into the guided prediction.

        When CFG is not enabled (see `_is_cfg_enabled`), `pred_cond` is used as-is and `pred_uncond` is ignored.
        If `guidance_rescale` is greater than zero, the result is additionally passed through `rescale_noise_cfg`.
        """
        if not self._is_cfg_enabled():
            pred = pred_cond
        else:
            shift = pred_cond - pred_uncond
            # The two formulations differ only in which prediction the scaled shift is added to.
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred

    @property
    def num_conditions(self) -> int:
        """Number of predictions needed for the current step: 2 when CFG is enabled, otherwise 1."""
        return 2 if self._is_cfg_enabled() else 1

    def _is_cfg_enabled(self) -> bool:
        """
        Return whether guidance should be applied for the current state.

        Guidance is disabled when the current step lies outside the `[start, stop)` window, or when the guidance
        scale makes the formula a no-op (`0.0` for the original formulation, `1.0` otherwise).
        """
        # Before `set_state` has been called the step information is still `None`; multiplying by it would raise
        # a `TypeError`, so in that case only the guidance scale decides.
        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        neutral_scale = 0.0 if self.use_original_formulation else 1.0
        is_close = math.isclose(self.guidance_scale, neutral_scale)
        return is_within_range and not is_close
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
16+
17+
import torch
18+
19+
from ..utils import get_logger
20+
21+
22+
if TYPE_CHECKING:
23+
from ..models.attention_processor import AttentionProcessor
24+
25+
26+
logger = get_logger(__name__) # pylint: disable=invalid-name
27+
28+
29+
class BaseGuidance:
    r"""
    Base class providing the skeleton for implementing guidance techniques.

    Subclasses must set the class attribute `_input_predictions` to a list of prediction names and implement
    `prepare_inputs`, `prepare_outputs`, `forward` and the `num_conditions` property.

    Args:
        start (`float`, defaults to `0.0`):
            Must satisfy `0.0 <= start < 1.0`.
        stop (`float`, defaults to `1.0`):
            Must satisfy `start <= stop <= 1.0`.

    Raises:
        ValueError: If `start`/`stop` are out of range, or if the subclass did not define `_input_predictions`
            as a list.
    """

    _input_predictions = None

    def __init__(self, start: float = 0.0, stop: float = 1.0):
        self._start = start
        self._stop = stop
        # Step information stays `None` until `set_state` is called.
        self._step: int = None
        self._num_inference_steps: int = None
        self._timestep: torch.LongTensor = None
        self._preds: Dict[str, torch.Tensor] = {}
        self._num_outputs_prepared: int = 0

        if not (0.0 <= start < 1.0):
            raise ValueError(
                f"Expected `start` to be between 0.0 and 1.0, but got {start}."
            )
        if not (start <= stop <= 1.0):
            raise ValueError(
                f"Expected `stop` to be between {start} and 1.0, but got {stop}."
            )

        # `isinstance` already rejects the `None` default, so no separate `is None` check is needed.
        if not isinstance(self._input_predictions, list):
            raise ValueError(
                "`_input_predictions` must be a list of required prediction names for the guidance technique."
            )

    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
        """Record the current step information and reset the stored predictions and their counter."""
        self._step = step
        self._num_inference_steps = num_inference_steps
        self._timestep = timestep
        self._preds = {}
        self._num_outputs_prepared = 0

    def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
        """Abstract hook; subclasses must override it."""
        raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")

    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
        """Abstract hook; subclasses must override it."""
        raise NotImplementedError("BaseGuidance::prepare_outputs must be implemented in subclasses.")

    def __call__(self, **kwargs) -> Any:
        """
        Validate the number of keyword arguments and dispatch to `forward`.

        Raises:
            ValueError: If the number of keyword arguments differs from `num_conditions`.
        """
        if len(kwargs) != self.num_conditions:
            raise ValueError(
                f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments."
            )
        return self.forward(**kwargs)

    def forward(self, *args, **kwargs) -> Any:
        """Abstract hook; subclasses must override it."""
        raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")

    @property
    def num_conditions(self) -> int:
        """Abstract property; subclasses must override it."""
        raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")

    @property
    def outputs(self) -> Dict[str, torch.Tensor]:
        """The dictionary of predictions stored since the last `set_state` call."""
        return self._preds
87+
88+
89+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescale the guided prediction so its per-sample standard deviation matches that of the text-conditioned
    prediction, then blend the rescaled and original tensors. Follows Section 3.4 of [Common Diffusion Noise
    Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg (`torch.Tensor`):
            The guided noise prediction.
        noise_pred_text (`torch.Tensor`):
            The text-conditioned noise prediction whose statistics are used as the target.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Blend weight of the rescaled tensor; `1 - guidance_rescale` is the weight of the unmodified `noise_cfg`.

    Returns:
        `torch.Tensor`: The blended noise prediction.
    """
    # Standard deviations are taken over every dimension except the first one.
    text_dims = list(range(1, noise_pred_text.ndim))
    cfg_dims = list(range(1, noise_cfg.ndim))
    std_ratio = noise_pred_text.std(dim=text_dims, keepdim=True) / noise_cfg.std(dim=cfg_dims, keepdim=True)

    # Matching the statistics of the conditional prediction counters overexposure ...
    matched = noise_cfg * std_ratio
    # ... and mixing the unscaled result back in avoids "plain looking" images.
    return guidance_rescale * matched + (1 - guidance_rescale) * noise_cfg
111+
112+
113+
def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
114+
"""
115+
Prepares the inputs for the denoiser by ensuring that the conditional and unconditional inputs are correctly
116+
prepared based on required number of conditions. This function is used in the `prepare_inputs` method of the
117+
`GuidanceMixin` class.
118+
119+
Either tensors or tuples/lists of tensors can be provided. If a tuple/list is provided, it should contain two elements:
120+
- The first element is the conditional input.
121+
- The second element is the unconditional input or None.
122+
123+
If only the conditional input is provided, it will be repeated for all batches.
124+
125+
If both conditional and unconditional inputs are provided, they are alternated as batches of data.
126+
"""
127+
list_of_inputs = []
128+
for arg in args:
129+
if arg is None or isinstance(arg, torch.Tensor):
130+
list_of_inputs.append([arg] * num_conditions)
131+
elif isinstance(arg, (tuple, list)):
132+
if len(arg) != 2:
133+
raise ValueError(
134+
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
135+
f"with the first element being the conditional input and the second element being the unconditional input or None."
136+
)
137+
if arg[1] is None:
138+
# Only conditioning inputs for all batches
139+
list_of_inputs.append([arg[0]] * num_conditions)
140+
else:
141+
# Alternating conditional and unconditional inputs as batches
142+
inputs = [arg[i % 2] for i in range(num_conditions)]
143+
list_of_inputs.append(inputs)
144+
else:
145+
raise ValueError(
146+
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
147+
)
148+
return tuple(list_of_inputs)

0 commit comments

Comments
 (0)