Commit 5003e5d

Same changes as in other PRs; add a check for running inpainting on an inpaint model without a source image
Co-Authored-By: Ryan Dick <[email protected]>
1 parent 58f3072 commit 5003e5d

3 files changed: +32 -31 lines

invokeai/app/invocations/denoise_latents.py (1 addition, 1 deletion)

@@ -718,7 +718,7 @@ def prepare_noise_and_latents(
         return seed, noise, latents
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        if os.environ.get("USE_MODULAR_DENOISE", False):
+        if os.environ.get("USE_MODULAR_DENOISE", True):
            return self._new_invoke(context)
        else:
            return self._old_invoke(context)
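
A side effect of the flipped default worth flagging: os.environ.get returns a string whenever the variable is set, and any non-empty string is truthy in Python, so after this change the modular path runs unless USE_MODULAR_DENOISE is set to the empty string. A quick sketch of the lookup semantics:

    import os

    os.environ.pop("USE_MODULAR_DENOISE", None)
    assert bool(os.environ.get("USE_MODULAR_DENOISE", True))   # unset -> default True

    os.environ["USE_MODULAR_DENOISE"] = "0"                    # still a non-empty string...
    assert bool(os.environ.get("USE_MODULAR_DENOISE", True))   # ...so still truthy

    os.environ["USE_MODULAR_DENOISE"] = ""                     # only the empty string is falsy
    assert not bool(os.environ.get("USE_MODULAR_DENOISE", True))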

invokeai/backend/stable_diffusion/extensions/inpaint.py (14 additions, 13 deletions)

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import einops
 import torch
@@ -20,27 +20,28 @@ def __init__(
         is_gradient_mask: bool,
     ):
         super().__init__()
-        self.mask = mask
-        self.is_gradient_mask = is_gradient_mask
+        self._mask = mask
+        self._is_gradient_mask = is_gradient_mask
+        self._noise: Optional[torch.Tensor] = None
 
     @staticmethod
     def _is_normal_model(unet: UNet2DConditionModel):
         return unet.conv_in.in_channels == 4
 
     def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
         batch_size = latents.size(0)
-        mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
+        mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
         if t.dim() == 0:
             # some schedulers expect t to be one-dimensional.
             # TODO: file diffusers bug about inconsistency?
             t = einops.repeat(t, "-> batch", batch=batch_size)
         # Noise shouldn't be re-randomized between steps here. The multistep schedulers
         # get very confused about what is happening from step to step when we do that.
-        mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self.noise, t)
+        mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t)
         # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
         # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
         mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
-        if self.is_gradient_mask:
+        if self._is_gradient_mask:
             threshhold = (t.item()) / ctx.scheduler.config.num_train_timesteps
             mask_bool = mask > threshhold  # I don't know when mask got inverted, but it did
             masked_input = torch.where(mask_bool, latents, mask_latents)
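
A note on the gradient-mask branch above: the threshold t.item() / num_train_timesteps shrinks as the denoise loop walks t toward zero, so pixels with large mask values are handed over to free denoising early, while softer mask regions are released later. A toy illustration of that progressive unlocking (mask values and timesteps hypothetical):

    import torch

    num_train_timesteps = 1000                  # typical scheduler config value
    mask = torch.tensor([0.0, 0.25, 0.5, 1.0])  # gradient mask values

    # Early step (t = 800): only the fully-masked pixel is denoised freely.
    print(mask > 800 / num_train_timesteps)     # tensor([False, False, False,  True])
    # Late step (t = 100): most of the gradient region has been released.
    print(mask > 100 / num_train_timesteps)     # tensor([False,  True,  True,  True])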
@@ -53,11 +54,11 @@ def init_tensors(self, ctx: DenoiseContext):
         if not self._is_normal_model(ctx.unet):
             raise Exception("InpaintExt should be used only on normal models!")
 
-        self.mask = self.mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+        self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
 
-        self.noise = ctx.inputs.noise
-        if self.noise is None:
-            self.noise = torch.randn(
+        self._noise = ctx.inputs.noise
+        if self._noise is None:
+            self._noise = torch.randn(
                 ctx.latents.shape,
                 dtype=torch.float32,
                 device="cpu",
@@ -85,7 +86,7 @@ def apply_mask_to_step_output(self, ctx: DenoiseContext):
     # restore unmasked part after the last step is completed
     @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
     def restore_unmasked(self, ctx: DenoiseContext):
-        if self.is_gradient_mask:
-            ctx.latents = torch.where(self.mask > 0, ctx.latents, ctx.inputs.orig_latents)
+        if self._is_gradient_mask:
+            ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
         else:
-            ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self.mask)
+            ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)
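
Two tensor idioms do most of the work in this extension: einops.repeat tiles the single-image mask across the batch, and torch.lerp uses the mask as a per-element interpolation weight between original and denoised latents (0 keeps the original, 1 keeps the result). A minimal self-contained sketch with hypothetical shapes:

    import einops
    import torch

    batch_size = 2
    mask = torch.rand(1, 1, 64, 64)                    # one mask, as held in self._mask
    latents = torch.randn(batch_size, 4, 64, 64)       # denoised result
    orig_latents = torch.randn(batch_size, 4, 64, 64)  # starting latents

    # Tile the mask across the batch, as in _apply_mask.
    b_mask = einops.repeat(mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
    assert b_mask.shape == (2, 1, 64, 64)

    # Blend as in restore_unmasked; the (1, 1, 64, 64) mask broadcasts
    # over the batch and channel dimensions.
    blended = torch.lerp(orig_latents, latents, mask)
    assert blended.shape == latents.shape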

invokeai/backend/stable_diffusion/extensions/inpaint_model.py (17 additions, 17 deletions)

@@ -20,9 +20,12 @@ def __init__(
         is_gradient_mask: bool,
     ):
         super().__init__()
-        self.mask = mask
-        self.masked_latents = masked_latents
-        self.is_gradient_mask = is_gradient_mask
+        if mask is not None and masked_latents is None:
+            raise ValueError("Source image required for inpaint mask when inpaint model used!")
+
+        self._mask = mask
+        self._masked_latents = masked_latents
+        self._is_gradient_mask = is_gradient_mask
 
     @staticmethod
     def _is_inpaint_model(unet: UNet2DConditionModel):
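
This is the check named in the commit message: requesting masked inpainting on an inpaint model without the VAE-encoded source image now fails fast in the constructor instead of producing a confusing failure mid-denoise. A hypothetical construction that would trip it (import path inferred from the file location above):

    import torch

    from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt

    mask = torch.ones(1, 1, 64, 64)  # an inpaint mask, but no source latents

    try:
        InpaintModelExt(mask=mask, masked_latents=None, is_gradient_mask=False)
    except ValueError as err:
        print(err)  # Source image required for inpaint mask when inpaint model used!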
@@ -33,21 +36,21 @@ def init_tensors(self, ctx: DenoiseContext):
         if not self._is_inpaint_model(ctx.unet):
             raise Exception("InpaintModelExt should be used only on inpaint models!")
 
-        if self.mask is None:
-            self.mask = torch.ones_like(ctx.latents[:1, :1])
-        self.mask = self.mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+        if self._mask is None:
+            self._mask = torch.ones_like(ctx.latents[:1, :1])
+        self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
 
-        if self.masked_latents is None:
-            self.masked_latents = torch.zeros_like(ctx.latents[:1])
-        self.masked_latents = self.masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+        if self._masked_latents is None:
+            self._masked_latents = torch.zeros_like(ctx.latents[:1])
+        self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
 
     # TODO: any ideas about order value?
     # do last so that other extensions works with normal latents
     @callback(ExtensionCallbackType.PRE_UNET, order=1000)
     def append_inpaint_layers(self, ctx: DenoiseContext):
         batch_size = ctx.unet_kwargs.sample.shape[0]
-        b_mask = torch.cat([self.mask] * batch_size)
-        b_masked_latents = torch.cat([self.masked_latents] * batch_size)
+        b_mask = torch.cat([self._mask] * batch_size)
+        b_masked_latents = torch.cat([self._masked_latents] * batch_size)
         ctx.unet_kwargs.sample = torch.cat(
             [ctx.unet_kwargs.sample, b_mask, b_masked_latents],
             dim=1,
@@ -57,10 +60,7 @@ def append_inpaint_layers(self, ctx: DenoiseContext):
     # restore unmasked part as inpaint model can change unmasked part slightly
     @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
     def restore_unmasked(self, ctx: DenoiseContext):
-        if self.mask is None:
-            return
-
-        if self.is_gradient_mask:
-            ctx.latents = torch.where(self.mask > 0, ctx.latents, ctx.inputs.orig_latents)
+        if self._is_gradient_mask:
+            ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
         else:
-            ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self.mask)
+            ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)
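
For context on why _is_inpaint_model can key off conv_in.in_channels: a standard Stable Diffusion UNet takes 4 latent channels, while the common inpaint checkpoints take 9 (4 noisy latents + 1 mask + 4 masked-image latents), which matches what append_inpaint_layers assembles above. A sketch of that concatenation with hypothetical shapes:

    import torch

    batch_size = 2
    sample = torch.randn(batch_size, 4, 64, 64)  # noisy latents fed to the UNet
    mask = torch.ones(1, 1, 64, 64)              # per-image inpaint mask
    masked_latents = torch.zeros(1, 4, 64, 64)   # VAE-encoded masked source

    # Tile per-image tensors across the batch, then stack along channels,
    # mirroring append_inpaint_layers.
    b_mask = torch.cat([mask] * batch_size)
    b_masked_latents = torch.cat([masked_latents] * batch_size)
    unet_input = torch.cat([sample, b_mask, b_masked_latents], dim=1)

    assert unet_input.shape == (batch_size, 9, 64, 64)  # 4 + 1 + 4 channels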
