|
| 1 | +# Copyright 2024 The HuggingFace Team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import math |
| 16 | +from typing import Optional, Union, Tuple, List |
| 17 | + |
| 18 | +import torch |
| 19 | + |
| 20 | +from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs |
| 21 | + |
| 22 | + |
| 23 | +class AdaptiveProjectedGuidance(BaseGuidance): |
| 24 | + """ |
| 25 | + Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416 |
| 26 | + |
| 27 | + Args: |
| 28 | + guidance_scale (`float`, defaults to `7.5`): |
| 29 | + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text |
| 30 | + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and |
| 31 | + deterioration of image quality. |
| 32 | + adaptive_projected_guidance_momentum (`float`, defaults to `None`): |
| 33 | + The momentum parameter for the adaptive projected guidance. Disabled if set to `None`. |
| 34 | + adaptive_projected_guidance_rescale (`float`, defaults to `15.0`): |
| 35 | + The rescale factor applied to the noise predictions. This is used to improve image quality and fix |
| 36 | + guidance_rescale (`float`, defaults to `0.0`): |
| 37 | + The rescale factor applied to the noise predictions. This is used to improve image quality and fix |
| 38 | + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are |
| 39 | + Flawed](https://huggingface.co/papers/2305.08891). |
| 40 | + use_original_formulation (`bool`, defaults to `False`): |
| 41 | + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, |
| 42 | + we use the diffusers-native implementation that has been in the codebase for a long time. See |
| 43 | + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. |
| 44 | + start (`float`, defaults to `0.0`): |
| 45 | + The fraction of the total number of denoising steps after which guidance starts. |
| 46 | + stop (`float`, defaults to `1.0`): |
| 47 | + The fraction of the total number of denoising steps after which guidance stops. |
| 48 | + """ |
| 49 | + |
| 50 | + _input_predictions = ["pred_cond", "pred_uncond"] |
| 51 | + |
| 52 | + def __init__( |
| 53 | + self, |
| 54 | + guidance_scale: float = 7.5, |
| 55 | + adaptive_projected_guidance_momentum: Optional[float] = None, |
| 56 | + adaptive_projected_guidance_rescale: float = 15.0, |
| 57 | + eta: float = 1.0, |
| 58 | + guidance_rescale: float = 0.0, |
| 59 | + use_original_formulation: bool = False, |
| 60 | + start: float = 0.0, |
| 61 | + stop: float = 1.0, |
| 62 | + ): |
| 63 | + super().__init__(start, stop) |
| 64 | + |
| 65 | + self.guidance_scale = guidance_scale |
| 66 | + self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum |
| 67 | + self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale |
| 68 | + self.eta = eta |
| 69 | + self.guidance_rescale = guidance_rescale |
| 70 | + self.use_original_formulation = use_original_formulation |
| 71 | + self.momentum_buffer = None |
| 72 | + |
| 73 | + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: |
| 74 | + if self._step == 0: |
| 75 | + if self.adaptive_projected_guidance_momentum is not None: |
| 76 | + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) |
| 77 | + return _default_prepare_inputs(denoiser, self.num_conditions, *args) |
| 78 | + |
| 79 | + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: |
| 80 | + self._num_outputs_prepared += 1 |
| 81 | + if self._num_outputs_prepared > self.num_conditions: |
| 82 | + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") |
| 83 | + key = self._input_predictions[self._num_outputs_prepared - 1] |
| 84 | + self._preds[key] = pred |
| 85 | + |
| 86 | + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: |
| 87 | + pred = None |
| 88 | + |
| 89 | + if not self._is_cfg_enabled(): |
| 90 | + pred = pred_cond |
| 91 | + else: |
| 92 | + pred = normalized_guidance( |
| 93 | + pred_cond, |
| 94 | + pred_uncond, |
| 95 | + self.guidance_scale, |
| 96 | + self.momentum_buffer, |
| 97 | + self.eta, |
| 98 | + self.adaptive_projected_guidance_rescale, |
| 99 | + self.use_original_formulation, |
| 100 | + ) |
| 101 | + |
| 102 | + if self.guidance_rescale > 0.0: |
| 103 | + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) |
| 104 | + |
| 105 | + return pred |
| 106 | + |
| 107 | + @property |
| 108 | + def is_conditional(self) -> bool: |
| 109 | + return self._num_outputs_prepared == 0 |
| 110 | + |
| 111 | + @property |
| 112 | + def num_conditions(self) -> int: |
| 113 | + num_conditions = 1 |
| 114 | + if self._is_cfg_enabled(): |
| 115 | + num_conditions += 1 |
| 116 | + return num_conditions |
| 117 | + |
| 118 | + def _is_cfg_enabled(self) -> bool: |
| 119 | + if not self._enabled: |
| 120 | + return False |
| 121 | + |
| 122 | + is_within_range = True |
| 123 | + if self._num_inference_steps is not None: |
| 124 | + skip_start_step = int(self._start * self._num_inference_steps) |
| 125 | + skip_stop_step = int(self._stop * self._num_inference_steps) |
| 126 | + is_within_range = skip_start_step <= self._step < skip_stop_step |
| 127 | + |
| 128 | + is_close = False |
| 129 | + if self.use_original_formulation: |
| 130 | + is_close = math.isclose(self.guidance_scale, 0.0) |
| 131 | + else: |
| 132 | + is_close = math.isclose(self.guidance_scale, 1.0) |
| 133 | + |
| 134 | + return is_within_range and not is_close |
| 135 | + |
| 136 | + |
| 137 | +class MomentumBuffer: |
| 138 | + def __init__(self, momentum: float): |
| 139 | + self.momentum = momentum |
| 140 | + self.running_average = 0 |
| 141 | + |
| 142 | + def update(self, update_value: torch.Tensor): |
| 143 | + new_average = self.momentum * self.running_average |
| 144 | + self.running_average = update_value + new_average |
| 145 | + |
| 146 | + |
| 147 | +def normalized_guidance( |
| 148 | + pred_cond: torch.Tensor, |
| 149 | + pred_uncond: torch.Tensor, |
| 150 | + guidance_scale: float, |
| 151 | + momentum_buffer: Optional[MomentumBuffer] = None, |
| 152 | + eta: float = 1.0, |
| 153 | + norm_threshold: float = 0.0, |
| 154 | + use_original_formulation: bool = False, |
| 155 | +): |
| 156 | + diff = pred_cond - pred_uncond |
| 157 | + dim = [-i for i in range(1, len(diff.shape))] |
| 158 | + if momentum_buffer is not None: |
| 159 | + momentum_buffer.update(diff) |
| 160 | + diff = momentum_buffer.running_average |
| 161 | + if norm_threshold > 0: |
| 162 | + ones = torch.ones_like(diff) |
| 163 | + diff_norm = diff.norm(p=2, dim=dim, keepdim=True) |
| 164 | + scale_factor = torch.minimum(ones, norm_threshold / diff_norm) |
| 165 | + diff = diff * scale_factor |
| 166 | + v0, v1 = diff.double(), pred_cond.double() |
| 167 | + v1 = torch.nn.functional.normalize(v1, dim=dim) |
| 168 | + v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 |
| 169 | + v0_orthogonal = v0 - v0_parallel |
| 170 | + diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) |
| 171 | + normalized_update = diff_orthogonal + eta * diff_parallel |
| 172 | + pred = pred_cond if use_original_formulation else pred_uncond |
| 173 | + pred = pred + (guidance_scale - 1) * normalized_update |
| 174 | + return pred |
0 commit comments