Skip to content

Commit 9e7b470

Browse files
committed
Handle inpaint models
1 parent f9c61f1 commit 9e7b470

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

invokeai/app/invocations/denoise_latents.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
5959
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
6060
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
61+
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
6162
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
6263
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
6364
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
@@ -790,6 +791,12 @@ def step_callback(state: PipelineIntermediateState) -> None:
790791

791792
ext_manager.add_extension(PreviewExt(step_callback))
792793

794+
### inpaint
795+
# TODO: add inpainting on normal model
796+
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
797+
if unet_config.variant == "inpaint": # ModelVariantType.Inpaint:
798+
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
799+
793800
# ext: t2i/ip adapter
794801
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
795802

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Optional
4+
5+
import torch
6+
from diffusers import UNet2DConditionModel
7+
8+
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
9+
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
10+
11+
if TYPE_CHECKING:
12+
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
13+
14+
15+
class InpaintModelExt(ExtensionBase):
    """Denoise-loop extension that drives a dedicated inpaint UNet variant.

    Inpaint-variant UNets take 9 input channels: the 4 latent channels plus a
    1-channel mask and the 4-channel masked-image latents, concatenated along
    the channel dimension before every UNet call (see ``append_inpaint_layers``).
    """

    def __init__(
        self,
        mask: Optional[torch.Tensor],
        masked_latents: Optional[torch.Tensor],
        is_gradient_mask: bool,
    ):
        """
        Args:
            mask: Inpaint mask in latent space, or ``None`` to inpaint the
                whole image (a mask of ones is substituted in ``init_tensors``).
            masked_latents: Latents of the masked source image, or ``None``
                to use zeros.
            is_gradient_mask: Whether ``mask`` holds soft (gradient) values
                rather than a hard binary mask; controls how the unmasked
                region is restored after denoising.
        """
        super().__init__()
        self.mask = mask
        self.masked_latents = masked_latents
        self.is_gradient_mask = is_gradient_mask

    @staticmethod
    def _is_inpaint_model(unet: UNet2DConditionModel) -> bool:
        """Heuristic check for an inpaint UNet: 9 input channels (4 latents + 1 mask + 4 masked latents)."""
        return unet.conv_in.in_channels == 9

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def init_tensors(self, ctx: DenoiseContext):
        """Validate the model and prepare mask/masked_latents on the latents' device and dtype.

        Raises:
            ValueError: If the UNet is not an inpaint variant.
        """
        if not self._is_inpaint_model(ctx.unet):
            raise ValueError("InpaintModelExt should be used only on inpaint model!")

        # No mask provided -> inpaint everywhere.
        if self.mask is None:
            self.mask = torch.ones_like(ctx.latents[:1, :1])
        self.mask = self.mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)

        # No masked-image latents provided -> zero conditioning.
        if self.masked_latents is None:
            self.masked_latents = torch.zeros_like(ctx.latents[:1])
        self.masked_latents = self.masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)

    # TODO: any ideas about order value?
    # do last so that other extensions works with normal latents
    @callback(ExtensionCallbackType.PRE_UNET, order=1000)
    def append_inpaint_layers(self, ctx: DenoiseContext):
        """Concatenate the broadcast mask and masked latents onto the UNet sample channels."""
        batch_size = ctx.unet_kwargs.sample.shape[0]
        # Broadcast the single-sample mask/masked_latents to the batch size.
        b_mask = torch.cat([self.mask] * batch_size)
        b_masked_latents = torch.cat([self.masked_latents] * batch_size)
        ctx.unet_kwargs.sample = torch.cat(
            [ctx.unet_kwargs.sample, b_mask, b_masked_latents],
            dim=1,
        )

    # TODO: should here be used order?
    # restore unmasked part as inpaint model can change unmasked part slightly
    @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
    def restore_unmasked(self, ctx: DenoiseContext):
        """Blend the original latents back into the unmasked region after denoising."""
        # Defensive: mask is always set by init_tensors, but guard anyway.
        if self.mask is None:
            return

        if self.is_gradient_mask:
            # Hard cutoff on the soft mask: keep denoised latents wherever mask > 0.
            ctx.latents = torch.where(self.mask > 0, ctx.latents, ctx.inputs.orig_latents)
        else:
            # Linear blend: mask acts as per-pixel interpolation weight.
            ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self.mask)

0 commit comments

Comments
 (0)