Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions invokeai/app/invocations/denoise_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
Expand Down Expand Up @@ -790,6 +791,10 @@ def step_callback(state: PipelineIntermediateState) -> None:

ext_manager.add_extension(PreviewExt(step_callback))

### cfg rescale
if self.cfg_rescale_multiplier > 0:
ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))

# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)

Expand Down
8 changes: 4 additions & 4 deletions invokeai/backend/stable_diffusion/diffusion_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> Scheduler
both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)

# ext: override apply_cfg
ctx.noise_pred = self.apply_cfg(ctx)
# ext: override combine_noise_preds
ctx.noise_pred = self.combine_noise_preds(ctx)

# ext: cfg_rescale [modify_noise_prediction]
# TODO: rename
ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
ext_manager.run_callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS, ctx)

# compute the previous noisy sample x_t -> x_t-1
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
Expand All @@ -95,7 +95,7 @@ def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> Scheduler
return step_output

@staticmethod
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[ctx.step_index]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ class ExtensionCallbackType(Enum):
POST_STEP = "post_step"
PRE_UNET = "pre_unet"
POST_UNET = "post_unet"
POST_APPLY_CFG = "post_apply_cfg"
POST_COMBINE_NOISE_PREDS = "post_combine_noise_preds"
36 changes: 36 additions & 0 deletions invokeai/backend/stable_diffusion/extensions/rescale_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


class RescaleCFGExt(ExtensionBase):
def __init__(self, rescale_multiplier: float):
super().__init__()
self.rescale_multiplier = rescale_multiplier

@staticmethod
def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
"""Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)

x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
return x_final

@callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
def rescale_noise_pred(self, ctx: DenoiseContext):
if self.rescale_multiplier > 0:
ctx.noise_pred = self._rescale_cfg(
ctx.noise_pred,
ctx.positive_noise_pred,
self.rescale_multiplier,
)
Loading