1
1
from __future__ import annotations
2
2
3
- from typing import TYPE_CHECKING
3
+ from typing import TYPE_CHECKING , Optional
4
4
5
5
import einops
6
6
import torch
@@ -20,27 +20,28 @@ def __init__(
20
20
is_gradient_mask : bool ,
21
21
):
22
22
super ().__init__ ()
23
- self .mask = mask
24
- self .is_gradient_mask = is_gradient_mask
23
+ self ._mask = mask
24
+ self ._is_gradient_mask = is_gradient_mask
25
+ self ._noise : Optional [torch .Tensor ] = None
25
26
26
27
@staticmethod
def _is_normal_model(unet: UNet2DConditionModel):
    """Return True if `unet` is a "normal" (non-inpainting-specific) model.

    The check is purely structural: a standard UNet takes 4 latent channels
    at its input convolution. NOTE(review): presumably inpainting-specific
    UNets use additional input channels (mask / masked-image latents) —
    confirm against the model definitions.
    """
    input_channels = unet.conv_in.in_channels
    return input_channels == 4
29
30
30
31
def _apply_mask (self , ctx : DenoiseContext , latents : torch .Tensor , t : torch .Tensor ) -> torch .Tensor :
31
32
batch_size = latents .size (0 )
32
- mask = einops .repeat (self .mask , "b c h w -> (repeat b) c h w" , repeat = batch_size )
33
+ mask = einops .repeat (self ._mask , "b c h w -> (repeat b) c h w" , repeat = batch_size )
33
34
if t .dim () == 0 :
34
35
# some schedulers expect t to be one-dimensional.
35
36
# TODO: file diffusers bug about inconsistency?
36
37
t = einops .repeat (t , "-> batch" , batch = batch_size )
37
38
# Noise shouldn't be re-randomized between steps here. The multistep schedulers
38
39
# get very confused about what is happening from step to step when we do that.
39
- mask_latents = ctx .scheduler .add_noise (ctx .inputs .orig_latents , self .noise , t )
40
+ mask_latents = ctx .scheduler .add_noise (ctx .inputs .orig_latents , self ._noise , t )
40
41
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
41
42
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
42
43
mask_latents = einops .repeat (mask_latents , "b c h w -> (repeat b) c h w" , repeat = batch_size )
43
- if self .is_gradient_mask :
44
+ if self ._is_gradient_mask :
44
45
threshhold = (t .item ()) / ctx .scheduler .config .num_train_timesteps
45
46
mask_bool = mask > threshhold # I don't know when mask got inverted, but it did
46
47
masked_input = torch .where (mask_bool , latents , mask_latents )
@@ -53,11 +54,11 @@ def init_tensors(self, ctx: DenoiseContext):
53
54
if not self ._is_normal_model (ctx .unet ):
54
55
raise Exception ("InpaintExt should be used only on normal models!" )
55
56
56
- self .mask = self .mask .to (device = ctx .latents .device , dtype = ctx .latents .dtype )
57
+ self ._mask = self ._mask .to (device = ctx .latents .device , dtype = ctx .latents .dtype )
57
58
58
- self .noise = ctx .inputs .noise
59
- if self .noise is None :
60
- self .noise = torch .randn (
59
+ self ._noise = ctx .inputs .noise
60
+ if self ._noise is None :
61
+ self ._noise = torch .randn (
61
62
ctx .latents .shape ,
62
63
dtype = torch .float32 ,
63
64
device = "cpu" ,
# restore unmasked part after the last step is completed
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
def restore_unmasked(self, ctx: DenoiseContext):
    """Put the unmasked region back from the original latents after the
    denoise loop finishes, so only the masked area keeps the new content."""
    originals = ctx.inputs.orig_latents
    if self._is_gradient_mask:
        # Hard threshold on the gradient mask: keep denoised latents only
        # where the mask is positive, restore originals elsewhere.
        keep_denoised = self._mask > 0
        ctx.latents = torch.where(keep_denoised, ctx.latents, originals)
    else:
        # Soft blend: linearly interpolate from originals toward the
        # denoised latents, weighted per-element by the mask.
        ctx.latents = torch.lerp(originals, ctx.latents, self._mask)