@@ -25,7 +25,7 @@ def __init__(
25
25
"""Initialize InpaintExt.
26
26
Args:
27
27
mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
28
- expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be
28
+ expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
29
29
inpainted.
30
30
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
31
31
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
@@ -65,10 +65,10 @@ def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tenso
65
65
mask_latents = einops .repeat (mask_latents , "b c h w -> (repeat b) c h w" , repeat = batch_size )
66
66
if self ._is_gradient_mask :
67
67
threshold = (t .item ()) / ctx .scheduler .config .num_train_timesteps
68
- mask_bool = mask > threshold
68
+ mask_bool = mask < 1 - threshold
69
69
masked_input = torch .where (mask_bool , latents , mask_latents )
70
70
else :
71
- masked_input = torch .lerp (mask_latents .to (dtype = latents .dtype ), latents , mask .to (dtype = latents .dtype ))
71
+ masked_input = torch .lerp (latents , mask_latents .to (dtype = latents .dtype ), mask .to (dtype = latents .dtype ))
72
72
return masked_input
73
73
74
74
@callback (ExtensionCallbackType .PRE_DENOISE_LOOP )
@@ -111,6 +111,6 @@ def apply_mask_to_step_output(self, ctx: DenoiseContext):
111
111
@callback (ExtensionCallbackType .POST_DENOISE_LOOP )
112
112
def restore_unmasked (self , ctx : DenoiseContext ):
113
113
if self ._is_gradient_mask :
114
- ctx .latents = torch .where (self ._mask > 0 , ctx .latents , ctx .inputs .orig_latents )
114
+ ctx .latents = torch .where (self ._mask < 1 , ctx .latents , ctx .inputs .orig_latents )
115
115
else :
116
- ctx .latents = torch .lerp (ctx .inputs . orig_latents , ctx .latents , self ._mask )
116
+ ctx .latents = torch .lerp (ctx .latents , ctx .inputs . orig_latents , self ._mask )
0 commit comments