- 
                Notifications
    You must be signed in to change notification settings 
- Fork 6.5k
Implement Frequency-Decoupled Guidance (FDG) as a Guider #11976
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
7d5901d
              fe824a8
              6949ece
              8c05d64
              33822e8
              565ce2a
              f608c5f
              34427b7
              c5070e0
              149c915
              0faa57a
              0a3f908
              259952a
              9c94aef
              4c379a4
              a4a829e
              d3dfb5f
              5d16521
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,317 @@ | ||
| # Copyright 2025 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|  | ||
| import math | ||
| from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union | ||
|  | ||
| import torch | ||
|  | ||
| from ..configuration_utils import register_to_config | ||
| from ..utils import is_kornia_available | ||
| from .guider_utils import BaseGuidance, rescale_noise_cfg | ||
|  | ||
|  | ||
| if TYPE_CHECKING: | ||
| from ..modular_pipelines.modular_pipeline import BlockState | ||
|  | ||
|  | ||
| _CAN_USE_KORNIA = is_kornia_available() | ||
|  | ||
|  | ||
| if _CAN_USE_KORNIA: | ||
| from kornia.geometry import pyrup as upsample_and_blur_func | ||
| from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func | ||
| else: | ||
| upsample_and_blur_func = None | ||
| build_laplacian_pyramid_func = None | ||
|  | ||
|  | ||
| def project(v0: torch.Tensor, v1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | ||
| """ | ||
| Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from | ||
| paper (Algorithm 2). | ||
| """ | ||
| # v0 shape: [B, ...] | ||
| # v1 shape: [B, ...] | ||
| dtype = v0.dtype | ||
| # Assume first dim is a batch dim and all other dims are channel or "spatial" dims | ||
| all_dims_but_first = list(range(1, len(v0.shape))) | ||
| v0, v1 = v0.double(), v1.double() | ||
| v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first) | ||
| v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1 | ||
| v0_orthogonal = v0 - v0_parallel | ||
| return v0_parallel.to(dtype), v0_orthogonal.to(dtype) | ||
|  | ||
|  | ||
| def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor: | ||
| """ | ||
| Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper | ||
| (Algorihtm 2). | ||
| """ | ||
| # pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...] | ||
| img = pyramid[-1] | ||
| for i in range(len(pyramid) - 2, -1, -1): | ||
| img = upsample_and_blur_func(img) + pyramid[i] | ||
| return img | ||
|  | ||
|  | ||
| class FrequencyDecoupledGuidance(BaseGuidance): | ||
| """ | ||
| Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713 | ||
|  | ||
| FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation | ||
| quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both | ||
| conditional and unconditional data, and use a combination of the two during inference. (If you want more details | ||
| on how CFG works, you can check out the CFG guider.) | ||
|  | ||
| FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency | ||
| components using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in | ||
| frequency space separately for the low- and high-frequency components with different guidance scales. Finally, the | ||
| inverse frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for | ||
| images) to form the final FDG prediction. | ||
|  | ||
| For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample | ||
| diversity and realistic color composition, while using high guidance scales for high-frequency components enhances | ||
| sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) | ||
| for the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an | ||
| example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper). | ||
|  | ||
| As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen | ||
| paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in | ||
| theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] | ||
|  | ||
| The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the | ||
| paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. | ||
|  | ||
| Args: | ||
| guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`): | ||
| The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest | ||
| frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower | ||
| values allow for more freedom in generation. Higher values may lead to saturation and deterioration of | ||
| image quality. The FDG authors recommend using higher guidance scales for higher frequency components and | ||
| lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in | ||
| descending order). | ||
| guidance_rescale (`float` or `List[float]`, defaults to `0.0`): | ||
| The rescale factor applied to the noise predictions. This is used to improve image quality and fix | ||
| overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are | ||
| Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as | ||
| `guidance_scales`. | ||
| parallel_weights (`float` or `List[float]`, *optional*): | ||
| Optional weights for the parallel component of each frequency component of the projected CFG shift. If not | ||
| set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift | ||
| (that is, equal weights for the parallel and orthogonal components). If a list is supplied, it should be | ||
| the same length as `guidance_scales`. | ||
| use_original_formulation (`bool`, defaults to `False`): | ||
| Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, | ||
| we use the diffusers-native implementation that has been in the codebase for a long time. See | ||
| [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. | ||
| start (`float` or `List[float]`, defaults to `0.0`): | ||
| The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it | ||
| should be the same length as `guidance_scales`. | ||
| stop (`float` or `List[float]`, defaults to `1.0`): | ||
| The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it | ||
| should be the same length as `guidance_scales`. | ||
| guidance_rescale_space (`str`, defaults to `"data"`): | ||
| Whether to performance guidance rescaling in `"data"` space (after the full FDG update in data space) or in | ||
| `"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is | ||
| speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value | ||
| will be used; otherwise, per-frequency-level guidance rescale values will be used if available. | ||
| """ | ||
|  | ||
| _input_predictions = ["pred_cond", "pred_uncond"] | ||
|  | ||
| @register_to_config | ||
| def __init__( | ||
| self, | ||
| guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0], | ||
| guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0, | ||
| parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None, | ||
| use_original_formulation: bool = False, | ||
| start: Union[float, List[float], Tuple[float]] = 0.0, | ||
| stop: Union[float, List[float], Tuple[float]] = 1.0, | ||
| guidance_rescale_space: str = "data", | ||
| ): | ||
| if not _CAN_USE_KORNIA: | ||
| raise ImportError( | ||
| "The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which" | ||
| "it depends is not available in the current environment." | ||
|          | ||
| ) | ||
|  | ||
| # Set start to earliest start for any freq component and stop to latest stop for any freq component | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! | ||
| min_start = start if isinstance(start, float) else min(start) | ||
| max_stop = stop if isinstance(stop, float) else max(stop) | ||
| super().__init__(min_start, max_stop) | ||
|  | ||
| self.guidance_scales = guidance_scales | ||
| self.levels = len(guidance_scales) | ||
|  | ||
| if isinstance(guidance_rescale, float): | ||
| self.guidance_rescale = [guidance_rescale] * self.levels | ||
| elif len(guidance_rescale) == self.levels: | ||
| self.guidance_rescale = guidance_rescale | ||
| else: | ||
| raise ValueError( | ||
| f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as " | ||
| f"`guidance_scales` ({len(self.guidance_scales)})" | ||
| ) | ||
| # Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after | ||
| # transforming from frequency space back to data space) | ||
| if guidance_rescale_space not in ["data", "freq"]: | ||
| raise ValueError( | ||
| f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`." | ||
| ) | ||
| self.guidance_rescale_space = guidance_rescale_space | ||
|  | ||
| if parallel_weights is None: | ||
| # Use normal CFG shift (equal weights for parallel and orthogonal components) | ||
| self.parallel_weights = [1.0] * self.levels | ||
| elif isinstance(parallel_weights, float): | ||
| self.parallel_weights = [parallel_weights] * self.levels | ||
| elif len(parallel_weights) == self.levels: | ||
| self.parallel_weights = parallel_weights | ||
| else: | ||
| raise ValueError( | ||
| f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as " | ||
| f"`guidance_scales` ({len(self.guidance_scales)})" | ||
| ) | ||
|  | ||
| self.use_original_formulation = use_original_formulation | ||
|  | ||
| if isinstance(start, float): | ||
| self.guidance_start = [start] * self.levels | ||
| elif len(start) == self.levels: | ||
| self.guidance_start = start | ||
| else: | ||
| raise ValueError( | ||
| f"`start` has length {len(start)} but should have the same length as `guidance_scales` " | ||
| f"({len(self.guidance_scales)})" | ||
| ) | ||
| if isinstance(stop, float): | ||
| self.guidance_stop = [stop] * self.levels | ||
| elif len(stop) == self.levels: | ||
| self.guidance_stop = stop | ||
| else: | ||
| raise ValueError( | ||
| f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` " | ||
| f"({len(self.guidance_scales)})" | ||
| ) | ||
|  | ||
| def prepare_inputs( | ||
| self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None | ||
| ) -> List["BlockState"]: | ||
| if input_fields is None: | ||
| input_fields = self._input_fields | ||
|  | ||
| tuple_indices = [0] if self.num_conditions == 1 else [0, 1] | ||
| data_batches = [] | ||
| for i in range(self.num_conditions): | ||
| data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) | ||
| data_batches.append(data_batch) | ||
| return data_batches | ||
|  | ||
| def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: | ||
| pred = None | ||
|  | ||
| if not self._is_fdg_enabled(): | ||
| pred = pred_cond | ||
| else: | ||
| # Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions. | ||
| pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels) | ||
| pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels) | ||
|  | ||
| # From high frequencies to low frequencies, following the paper implementation | ||
| pred_guided_pyramid = [] | ||
| parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale) | ||
| for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters): | ||
| if self._is_fdg_enabled_for_level(level): | ||
| # Get the cond/uncond preds (in freq space) at the current frequency level | ||
| pred_cond_freq = pred_cond_pyramid[level] | ||
| pred_uncond_freq = pred_uncond_pyramid[level] | ||
|  | ||
| shift = pred_cond_freq - pred_uncond_freq | ||
|  | ||
| # Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift) | ||
| if not math.isclose(parallel_weight, 1.0): | ||
| shift_parallel, shift_orthogonal = project(shift, pred_cond_freq) | ||
| shift = parallel_weight * shift_parallel + shift_orthogonal | ||
|  | ||
| # Apply CFG update for the current frequency level | ||
| pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq | ||
| pred = pred + guidance_scale * shift | ||
|  | ||
| if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0: | ||
| pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale) | ||
|  | ||
| # Add the current FDG guided level to the FDG prediction pyramid | ||
| pred_guided_pyramid.append(pred) | ||
| else: | ||
| # Add the current pred_cond_pyramid level as the "non-FDG" prediction | ||
| pred_guided_pyramid.append(pred_cond_freq) | ||
|  | ||
| # Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform | ||
| pred = build_image_from_pyramid(pred_guided_pyramid) | ||
|  | ||
| # If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value | ||
| # across all freq levels | ||
| if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0: | ||
| pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0]) | ||
|  | ||
| return pred, {} | ||
|  | ||
| @property | ||
| def is_conditional(self) -> bool: | ||
| return self._count_prepared == 1 | ||
|  | ||
| @property | ||
| def num_conditions(self) -> int: | ||
| num_conditions = 1 | ||
| if self._is_fdg_enabled(): | ||
| num_conditions += 1 | ||
| return num_conditions | ||
|  | ||
| def _is_fdg_enabled(self) -> bool: | ||
| if not self._enabled: | ||
| return False | ||
|  | ||
| is_within_range = True | ||
| if self._num_inference_steps is not None: | ||
| skip_start_step = int(self._start * self._num_inference_steps) | ||
| skip_stop_step = int(self._stop * self._num_inference_steps) | ||
| is_within_range = skip_start_step <= self._step < skip_stop_step | ||
|  | ||
| is_close = False | ||
| if self.use_original_formulation: | ||
| is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales) | ||
| else: | ||
| is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales) | ||
|  | ||
| return is_within_range and not is_close | ||
|  | ||
| def _is_fdg_enabled_for_level(self, level: int) -> bool: | ||
| if not self._enabled: | ||
| return False | ||
|  | ||
| is_within_range = True | ||
| if self._num_inference_steps is not None: | ||
| skip_start_step = int(self.guidance_start[level] * self._num_inference_steps) | ||
| skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps) | ||
| is_within_range = skip_start_step <= self._step < skip_stop_step | ||
|  | ||
| is_close = False | ||
| if self.use_original_formulation: | ||
| is_close = math.isclose(self.guidance_scales[level], 0.0) | ||
| else: | ||
| is_close = math.isclose(self.guidance_scales[level], 1.0) | ||
|  | ||
| return is_within_range and not is_close | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dg845 @Msadat97 Just curious whether this must be float64 and if you've tested the same with float32/lower-dtype and found it harmful? The operations here are very few, but fp64 is extremely slow and I wonder if this has any impact on the overall runtime (maybe negligible for images, but might be worth understanding for when number of tokens is larger, like in video models, and if the dtype here could be potentially user-configurable).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
projectfunction is called whenparallel_weightsis set (and is not the default value of1.0), so the upcasted operations will only be performed sometimes.For now, I have added a
upcast_to_doubleargument which controls whetherprojectwill upcast to fp64.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here are some FDG samples which use a guidance scale of
[10.0, 5.0]and aparallel_weightsof1.5. The1.5value is somewhat arbitrary; @Msadat97, what is a reasonable range of values forparallel_weights?FDG with guidance scales
[10.0, 5.0],parallel_weights=1.5, upcast to double:FDG with guidance scales
[10.0, 5.0],parallel_weights=1.5, no upcast to double (with pipeline at fp16):In this case, the images look of similar quality with and without upcasting (with perhaps a slight reduction in quality for the non-upcasted version).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We haven’t specifically tested the FP32 projection part, but I’m not sure how much it affects performance in this case, as the operations involved are quite lightweight and the model still runs in FP16. I just felt it might be safer to use double for normalization and projection to improve numerical accuracy a bit.
Regarding the parallel component, I think it’s best to keep the weight below 1. A value like 0.5 should give a good balance. That said, we used 1 in most parts of the paper and treated it as optional.
@dg845 One last question: are you using the noise prediction (i.e., the model output) for FDG, or the x_0 prediction? Perhaps using x_0 might be better, since frequency decomposition is likely more meaningful there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, I am using the raw model output, whether that's$x_0$ -prediction, $\epsilon$ -prediction, $v$ -prediction, etc.
I believe it would be difficult to use the$x_0$ -prediction in the current modular pipeline design because this would require the FDG guider to know the internal state of the scheduler (in particular, the $x_0$  prediction). It would also be a little unnatural because in general the scheduler will execute after the guider, and the scheduler's $x_0$ -prediction internally, so in the FDG guider we would probably have to convert to $x_0$ -prediction, get the FDG prediction, and then convert back to the original 
prediction_typeand thebeta/sigma/etc. schedule to calculate thestepmethod usually expects a raw model output and will convert to anprediction_typeso that the FDG prediction can be used as expected in the scheduler. @yiyixuxu thoughts?Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That’s how we implement FDG as well, and it’s similar to how Adaptive Projected Guidance (APG) was handled in the guiders. So I assume it should also be compatible with FDG?
P.S.: btw, this conversion is mainly useful for projection to be more meaningful. Otherwise, it's almost the same for all prediction types, since the frequency operations are linear.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the$x_0$ -prediction, $\epsilon$ -prediction, etc., and it doesn't convert to $x_0$  internally.
AdaptiveProjectedGuidanceguider is implemented in the same way the FDG guider is currently implemented: theforwardmethod takes inpred_condandpred_uncondarguments but is agnostic as to whether these inputs areMy statement above that the FDG guider uses the raw model output is probably a little misleading, in the sense that this assumes that the calling code will supply the denoising model's output to the FDG guider. This is the case in e.g.
StableDiffusionXLLoopDenoiser:diffusers/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py
Lines 232 to 243 in e46e139
but we could imagine that the calling$x_0$  before calling the guider, and then convert back after the guider executes. This would require the scheduler to be in the block's 
PipelineBlock(such asStableDiffusionXLLoopDenoiser) could instead do the conversion toexpected_components, and in this case we'd probably want the guider to expose a config likeshould_convert_to_sample_predictionand the scheduler to exposeconvert_to_sample_prediction/convert_to_prediction_typemethods.In general, I think it may make more sense to do something like$x_0$  conversion in the calling 
PipelineBlock, since in the current designPipelineBlocks can have access to the scheduler whereas the guider itself shouldn't be coupled to the scheduler.