9 changes: 7 additions & 2 deletions invokeai/app/invocations/denoise_latents.py
@@ -58,6 +58,7 @@
 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
@@ -790,18 +791,22 @@ def step_callback(state: PipelineIntermediateState) -> None:

         ext_manager.add_extension(PreviewExt(step_callback))

+        ### freeu
+        if self.unet.freeu_config:
+            ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
+
         # ext: t2i/ip adapter
         ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)

         unet_info = context.models.load(self.unet.unet)
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
-            unet_info.model_on_device() as (model_state_dict, unet),
+            unet_info.model_on_device() as (cached_weights, unet),
             ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
             # ext: controlnet
             ext_manager.patch_extensions(unet),
             # ext: freeu, seamless, ip adapter, lora
-            ext_manager.patch_unet(model_state_dict, unet),
+            ext_manager.patch_unet(unet, cached_weights),
         ):
             sd_backend = StableDiffusionBackend(unet, scheduler)
             denoise_ctx.unet = unet
6 changes: 4 additions & 2 deletions invokeai/backend/stable_diffusion/diffusion_backend.py
@@ -100,8 +100,10 @@ def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
         if isinstance(guidance_scale, list):
             guidance_scale = guidance_scale[ctx.step_index]

-        return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
-        # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
+        # Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
+        # in slightly different outputs. It is suspected that this is caused by small precision differences.
+        # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
+        return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)

     def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
         sample = ctx.latent_model_input
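Reviewer note: the two CFG formulations in this hunk are algebraically identical, since `torch.lerp(a, b, w)` computes `a + w * (b - a)`; the divergence the comment describes comes from floating-point rounding, which can differ between the lerp kernel and the unfused expression. A minimal sketch that makes the difference observable; the shapes, dtype, and guidance scale are illustrative assumptions, not values from this PR:

```python
import torch

# Compare the two CFG formulations from apply_cfg. They agree exactly in
# real arithmetic but can round differently, most visibly in float16.
torch.manual_seed(0)
neg = torch.randn(1, 4, 64, 64, dtype=torch.float16)  # negative_noise_pred stand-in
pos = torch.randn(1, 4, 64, 64, dtype=torch.float16)  # positive_noise_pred stand-in
guidance_scale = 7.5

via_lerp = torch.lerp(neg, pos, guidance_scale)
via_expr = neg + guidance_scale * (pos - neg)

# Typically a small but nonzero max difference, which is why the PR keeps
# the explicit expression and leaves the lerp form commented out.
print((via_lerp - via_expr).abs().max().item())
```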
4 changes: 2 additions & 2 deletions invokeai/backend/stable_diffusion/extensions/base.py
@@ -2,7 +2,7 @@

 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional

 import torch
 from diffusers import UNet2DConditionModel
@@ -56,5 +56,5 @@ def patch_extension(self, context: DenoiseContext):
         yield None

     @contextmanager
-    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
         yield None
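Reviewer note: the reordered signature makes the original weights an optional input. `cached_weights`, when supplied by the model cache via `model_on_device()`, lets a weight-patching extension restore tensors on exit without snapshotting them itself. A hedged sketch of that pattern, assuming `cached_weights` holds separate copies of the original tensors; the extension class and layer name are hypothetical, not part of this PR:

```python
from contextlib import contextmanager
from typing import Dict, Optional

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase


class ScaleConvInExt(ExtensionBase):
    """Hypothetical extension: patches one weight in-place, then restores it."""

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
        key = "conv_in.weight"
        param = unet.get_parameter(key)
        # Prefer the cache's copy of the original weight; otherwise snapshot it.
        original = cached_weights[key] if cached_weights is not None else param.detach().clone()
        try:
            param.data *= 1.1  # the in-place "patch"
            yield
        finally:
            param.data.copy_(original)  # undo the patch on exit
```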
35 changes: 35 additions & 0 deletions invokeai/backend/stable_diffusion/extensions/freeu.py
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Dict, Optional
+
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
+
+if TYPE_CHECKING:
+    from invokeai.app.shared.models import FreeUConfig
+
+
+class FreeUExt(ExtensionBase):
+    def __init__(
+        self,
+        freeu_config: FreeUConfig,
+    ):
+        super().__init__()
+        self._freeu_config = freeu_config
+
+    @contextmanager
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        unet.enable_freeu(
+            b1=self._freeu_config.b1,
+            b2=self._freeu_config.b2,
+            s1=self._freeu_config.s1,
+            s2=self._freeu_config.s2,
+        )
+
+        try:
+            yield
+        finally:
+            unet.disable_freeu()
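Reviewer note: for anyone trying the new extension out of context, a minimal usage sketch. The b1/b2/s1/s2 values are the FreeU paper's commonly cited SD1.x settings, and the bare `ExtensionsManager()` construction is an assumption, not guidance from this PR:

```python
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager

# Illustrative SD1.x-style FreeU settings; tune per model in practice.
config = FreeUConfig(b1=1.2, b2=1.4, s1=0.9, s2=0.2)

ext_manager = ExtensionsManager()
ext_manager.add_extension(FreeUExt(config))

# During denoising, ExtensionsManager.patch_unet enters FreeUExt.patch_unet,
# which calls unet.enable_freeu(...) and always disable_freeu() on exit:
#
#     with ext_manager.patch_unet(unet, cached_weights):
#         ...run the denoise loop...
```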
10 changes: 7 additions & 3 deletions invokeai/backend/stable_diffusion/extensions_manager.py
@@ -63,9 +63,13 @@ def patch_extensions(self, context: DenoiseContext):
         yield None

     @contextmanager
-    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
         if self._is_canceled and self._is_canceled():
             raise CanceledException

-        # TODO: create logic in PR with extension which uses it
-        yield None
+        # TODO: create weight patch logic in PR with extension which uses it
+        with ExitStack() as exit_stack:
+            for ext in self._extensions:
+                exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
+
+            yield None
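Reviewer note: `ExitStack` composes every extension's `patch_unet` context, so patches apply in registration order and unwind in reverse even if one of them raises. A standalone sketch of that behavior, with toy contexts standing in for extensions:

```python
from contextlib import ExitStack, contextmanager


@contextmanager
def patch(name: str):
    print(f"apply {name}")
    try:
        yield
    finally:
        print(f"revert {name}")


# Mirrors ExtensionsManager.patch_unet: enter each patch, run the body,
# then revert in last-in, first-out order.
with ExitStack() as stack:
    for name in ("freeu", "lora", "seamless"):
        stack.enter_context(patch(name))
    print("denoise")

# Output:
#   apply freeu, apply lora, apply seamless, denoise,
#   revert seamless, revert lora, revert freeu
```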