Skip to content

Commit e046e60

Browse files
committed
Add FreeU support to denoise
1 parent f9c61f1 commit e046e60

File tree

3 files changed

+56
-5
lines changed

3 files changed

+56
-5
lines changed

invokeai/app/invocations/denoise_latents.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
5959
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
6060
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
61+
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
6162
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
6263
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
6364
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
@@ -790,18 +791,22 @@ def step_callback(state: PipelineIntermediateState) -> None:
790791

791792
ext_manager.add_extension(PreviewExt(step_callback))
792793

794+
### freeu
795+
if self.unet.freeu_config:
796+
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
797+
793798
# ext: t2i/ip adapter
794799
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
795800

796801
unet_info = context.models.load(self.unet.unet)
797802
assert isinstance(unet_info.model, UNet2DConditionModel)
798803
with (
799-
unet_info.model_on_device() as (model_state_dict, unet),
804+
unet_info.model_on_device() as (cached_weights, unet),
800805
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
801806
# ext: controlnet
802807
ext_manager.patch_extensions(unet),
803808
# ext: freeu, seamless, ip adapter, lora
804-
ext_manager.patch_unet(model_state_dict, unet),
809+
ext_manager.patch_unet(unet, cached_weights),
805810
):
806811
sd_backend = StableDiffusionBackend(unet, scheduler)
807812
denoise_ctx.unet = unet
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from __future__ import annotations
2+
3+
from contextlib import contextmanager
4+
from typing import TYPE_CHECKING, Dict, Optional
5+
6+
from diffusers import UNet2DConditionModel
7+
8+
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
9+
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
10+
11+
if TYPE_CHECKING:
12+
from invokeai.app.shared.models import FreeUConfig
13+
14+
15+
class FreeUExt(ExtensionBase):
16+
def __init__(
17+
self,
18+
freeu_config: Optional[FreeUConfig],
19+
):
20+
super().__init__()
21+
self.freeu_config = freeu_config
22+
23+
@contextmanager
24+
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
25+
did_apply_freeu = False
26+
try:
27+
assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute?
28+
if self.freeu_config is not None:
29+
unet.enable_freeu(
30+
b1=self.freeu_config.b1,
31+
b2=self.freeu_config.b2,
32+
s1=self.freeu_config.s1,
33+
s2=self.freeu_config.s2,
34+
)
35+
did_apply_freeu = True
36+
37+
yield
38+
39+
finally:
40+
assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute?
41+
if did_apply_freeu:
42+
unet.disable_freeu()

invokeai/backend/stable_diffusion/extensions_manager.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,13 @@ def patch_extensions(self, context: DenoiseContext):
6363
yield None
6464

6565
@contextmanager
66-
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
66+
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
6767
if self._is_canceled and self._is_canceled():
6868
raise CanceledException
6969

70-
# TODO: create logic in PR with extension which uses it
71-
yield None
70+
# TODO: create weight patch logic in PR with extension which uses it
71+
with ExitStack() as exit_stack:
72+
for ext in self._extensions:
73+
exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
74+
75+
yield None

0 commit comments

Comments
 (0)