Skip to content

Commit 19c0024

Browse files
committed
Use non-inverted mask generally (except inpaint model handling)
1 parent c323a76 commit 19c0024

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

invokeai/app/invocations/denoise_latents.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ def prep_inpaint_mask(
674674
else:
675675
masked_latents = torch.where(mask < 0.5, 0.0, latents)
676676

677-
return 1 - mask, masked_latents, self.denoise_mask.gradient
677+
return mask, masked_latents, self.denoise_mask.gradient
678678

679679
@staticmethod
680680
def prepare_noise_and_latents(
@@ -830,6 +830,8 @@ def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
830830
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
831831

832832
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
833+
if mask is not None:
834+
mask = 1 - mask
833835

834836
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
835837
# below. Investigate whether this is appropriate.

invokeai/backend/stable_diffusion/extensions/inpaint.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(
2525
"""Initialize InpaintExt.
2626
Args:
2727
mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
28-
expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be
28+
expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
2929
inpainted.
3030
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
3131
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
@@ -65,10 +65,10 @@ def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tenso
6565
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
6666
if self._is_gradient_mask:
6767
threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
68-
mask_bool = mask > threshold
68+
mask_bool = mask < 1 - threshold
6969
masked_input = torch.where(mask_bool, latents, mask_latents)
7070
else:
71-
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
71+
masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))
7272
return masked_input
7373

7474
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
@@ -111,6 +111,6 @@ def apply_mask_to_step_output(self, ctx: DenoiseContext):
111111
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
112112
def restore_unmasked(self, ctx: DenoiseContext):
113113
if self._is_gradient_mask:
114-
ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
114+
ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents)
115115
else:
116-
ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)
116+
ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask)

invokeai/backend/stable_diffusion/extensions/inpaint_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(
2525
"""Initialize InpaintModelExt.
2626
Args:
2727
mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
28-
expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be
28+
expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
2929
inpainted.
3030
masked_latents (Optional[torch.Tensor]): Latents of the initial image, with the inpainted area masked out by black color.
3131
If mask is provided, then masked_latents should be provided too. Shape: (1, 1, latent_height, latent_width)
@@ -37,7 +37,10 @@ def __init__(
3737
if mask is not None and masked_latents is None:
3838
raise ValueError("Source image required for inpaint mask when inpaint model used!")
3939

40-
self._mask = mask
40+
# Inverse mask, because inpaint models treat mask as: 0 - remain same, 1 - inpaint
41+
self._mask = None
42+
if mask is not None:
43+
self._mask = 1 - mask
4144
self._masked_latents = masked_latents
4245
self._is_gradient_mask = is_gradient_mask
4346

0 commit comments

Comments
 (0)