Skip to content

Commit 58f3072

Browse files
committed
Handle inpainting on normal models
1 parent 9e7b470 commit 58f3072

File tree

3 files changed

+97
-4
lines changed

3 files changed

+97
-4
lines changed

invokeai/app/invocations/denoise_latents.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from invokeai.app.util.controlnet_utils import prepare_control_image
3838
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
3939
from invokeai.backend.lora import LoRAModelRaw
40-
from invokeai.backend.model_manager import BaseModelType
40+
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
4141
from invokeai.backend.model_patcher import ModelPatcher
4242
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
4343
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
@@ -58,6 +58,7 @@
5858
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
5959
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
6060
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
61+
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
6162
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
6263
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
6364
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
@@ -792,10 +793,11 @@ def step_callback(state: PipelineIntermediateState) -> None:
792793
ext_manager.add_extension(PreviewExt(step_callback))
793794

794795
### inpaint
795-
# TODO: add inpainting on normal model
796796
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
797-
if unet_config.variant == "inpaint": # ModelVariantType.Inpaint:
797+
if unet_config.variant == ModelVariantType.Inpaint:
798798
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
799+
elif mask is not None:
800+
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
799801

800802
# ext: t2i/ip adapter
801803
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import einops
6+
import torch
7+
from diffusers import UNet2DConditionModel
8+
9+
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
10+
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
11+
12+
if TYPE_CHECKING:
13+
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
14+
15+
16+
class InpaintExt(ExtensionBase):
17+
def __init__(
18+
self,
19+
mask: torch.Tensor,
20+
is_gradient_mask: bool,
21+
):
22+
super().__init__()
23+
self.mask = mask
24+
self.is_gradient_mask = is_gradient_mask
25+
26+
@staticmethod
27+
def _is_normal_model(unet: UNet2DConditionModel):
28+
return unet.conv_in.in_channels == 4
29+
30+
def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
31+
batch_size = latents.size(0)
32+
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
33+
if t.dim() == 0:
34+
# some schedulers expect t to be one-dimensional.
35+
# TODO: file diffusers bug about inconsistency?
36+
t = einops.repeat(t, "-> batch", batch=batch_size)
37+
# Noise shouldn't be re-randomized between steps here. The multistep schedulers
38+
# get very confused about what is happening from step to step when we do that.
39+
mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self.noise, t)
40+
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
41+
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
42+
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
43+
if self.is_gradient_mask:
44+
threshhold = (t.item()) / ctx.scheduler.config.num_train_timesteps
45+
mask_bool = mask > threshhold # I don't know when mask got inverted, but it did
46+
masked_input = torch.where(mask_bool, latents, mask_latents)
47+
else:
48+
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
49+
return masked_input
50+
51+
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
52+
def init_tensors(self, ctx: DenoiseContext):
53+
if not self._is_normal_model(ctx.unet):
54+
raise Exception("InpaintExt should be used only on normal models!")
55+
56+
self.mask = self.mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
57+
58+
self.noise = ctx.inputs.noise
59+
if self.noise is None:
60+
self.noise = torch.randn(
61+
ctx.latents.shape,
62+
dtype=torch.float32,
63+
device="cpu",
64+
generator=torch.Generator(device="cpu").manual_seed(ctx.seed),
65+
).to(device=ctx.latents.device, dtype=ctx.latents.dtype)
66+
67+
# TODO: order value
68+
@callback(ExtensionCallbackType.PRE_STEP, order=-100)
69+
def apply_mask_to_initial_latents(self, ctx: DenoiseContext):
70+
ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)
71+
72+
# TODO: order value
73+
# TODO: redo this with preview events rewrite
74+
@callback(ExtensionCallbackType.POST_STEP, order=-100)
75+
def apply_mask_to_step_output(self, ctx: DenoiseContext):
76+
timestep = ctx.scheduler.timesteps[-1]
77+
if hasattr(ctx.step_output, "denoised"):
78+
ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep)
79+
elif hasattr(ctx.step_output, "pred_original_sample"):
80+
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep)
81+
else:
82+
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)
83+
84+
# TODO: should here be used order?
85+
# restore unmasked part after the last step is completed
86+
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
87+
def restore_unmasked(self, ctx: DenoiseContext):
88+
if self.is_gradient_mask:
89+
ctx.latents = torch.where(self.mask > 0, ctx.latents, ctx.inputs.orig_latents)
90+
else:
91+
ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self.mask)

invokeai/backend/stable_diffusion/extensions/inpaint_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _is_inpaint_model(unet: UNet2DConditionModel):
3131
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
3232
def init_tensors(self, ctx: DenoiseContext):
3333
if not self._is_inpaint_model(ctx.unet):
34-
raise Exception("InpaintModelExt should be used only on inpaint model!")
34+
raise Exception("InpaintModelExt should be used only on inpaint models!")
3535

3636
if self.mask is None:
3737
self.mask = torch.ones_like(ctx.latents[:1, :1])

0 commit comments

Comments
 (0)