@@ -134,11 +134,7 @@ def txt2img_image_conditioning(self, x, width=None, height=None):
             # Dummy zero conditioning if we're not using inpainting model.
             # Still takes up a bit of memory, but no encoder call.
             # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
-            return torch.zeros(
-                x.shape[0], 5, 1, 1,
-                dtype=x.dtype,
-                device=x.device
-            )
+            return x.new_zeros(x.shape[0], 5, 1, 1)

         height = height or self.height
         width = width or self.width
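For context on the change above: `Tensor.new_zeros` (like the related `new_ones`) creates a tensor that inherits the source tensor's dtype and device, which is why the explicit `dtype=`/`device=` arguments can be dropped. A minimal sketch of the idiom, with an illustrative stand-in tensor:

```python
import torch

# Stand-in for a latent batch; any dtype/device combination works.
x = torch.randn(4, 4, 64, 64, dtype=torch.float32)

# new_zeros inherits x's dtype and device, unlike a bare torch.zeros call,
# which would default to float32 on the CPU unless told otherwise.
cond = x.new_zeros(x.shape[0], 5, 1, 1)

assert cond.dtype == x.dtype
assert cond.device == x.device
assert cond.shape == (4, 5, 1, 1)
```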
@@ -156,11 +152,7 @@ def txt2img_image_conditioning(self, x, width=None, height=None):
     def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
         if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
             # Dummy zero conditioning if we're not using inpainting model.
-            return torch.zeros(
-                latent_image.shape[0], 5, 1, 1,
-                dtype=latent_image.dtype,
-                device=latent_image.device
-            )
+            return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

         # Handle the different mask inputs
         if image_mask is not None:
@@ -174,11 +166,10 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
                 # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                 conditioning_mask = torch.round(conditioning_mask)
         else:
-            conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
+            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

         # Create another latent image, this time with a masked version of the original input.
         # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
-        conditioning_mask = conditioning_mask.to(source_image.device)
         conditioning_image = torch.lerp(
             source_image,
             source_image * (1.0 - conditioning_mask),
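As context for the `torch.lerp` call in this hunk: `lerp(input, end, weight)` computes `input + weight * (end - input)`, so with `end = source_image * (1.0 - conditioning_mask)` a weight of 0 keeps the conditioning image untouched and a weight of 1 fully zeroes the masked region. A small sketch of that blending behaviour (the tensor names here are illustrative, not taken from the PR):

```python
import torch

source = torch.ones(1, 3, 4, 4)   # stand-in conditioning image
mask = torch.zeros(1, 1, 4, 4)
mask[..., :2, :] = 1.0            # mask the top half

masked = source * (1.0 - mask)    # masked-out pixels become 0

# weight=0.0 -> original image; weight=1.0 -> fully masked image
blended = torch.lerp(source, masked, 0.5)
print(blended[0, 0, 0, 0].item())  # 0.5 inside the masked region
print(blended[0, 0, 3, 0].item())  # 1.0 outside the mask
```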
@@ -653,7 +644,13 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs

         if opts.use_scale_latent_for_hires_fix:
             samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
-            image_conditioning = self.txt2img_image_conditioning(samples)
+
+            # Avoid making the inpainting conditioning unless necessary as
+            # this does need some extra compute to decode / encode the image again.
+            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
+                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
+            else:
+                image_conditioning = self.txt2img_image_conditioning(samples)

         else:
             decoded_samples = decode_first_stage(self.sd_model, samples)
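The `getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)` guard reads a per-processing override if one was set on the object and otherwise falls back to the global option, and the expensive decode/encode path is only taken when that weight actually matters (`< 1.0`). A tiny sketch of the fallback pattern, with hypothetical stand-in classes:

```python
class Opts:
    inpainting_mask_weight = 1.0  # hypothetical stand-in for shared.opts

class Processing:
    pass  # hypothetical stand-in for the processing object

opts = Opts()
p = Processing()

# No override set on p -> fall back to the global option.
print(getattr(p, "inpainting_mask_weight", opts.inpainting_mask_weight))  # 1.0

# A per-job override on the object takes precedence.
p.inpainting_mask_weight = 0.5
print(getattr(p, "inpainting_mask_weight", opts.inpainting_mask_weight))  # 0.5
```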
@@ -675,11 +672,7 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs

             samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

-            image_conditioning = self.img2img_image_conditioning(
-                decoded_samples,
-                samples,
-                decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
-            )
+            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

         shared.state.nextjob()
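Dropping the explicit all-ones mask at this call site is safe because, with the earlier change, `image_mask=None` makes `img2img_image_conditioning` build a full-coverage mask via `source_image.new_ones(1, 1, *source_image.shape[-2:])`, which broadcasts over the batch dimension inside the lerp. A hedged sketch of that equivalence (names mirror the diff, the tensors are illustrative):

```python
import torch

source_image = torch.randn(2, 3, 64, 64)

# Mask the removed call site constructed explicitly, one per batch item:
explicit = source_image.new_ones(source_image.shape[0], 1, source_image.shape[2], source_image.shape[3])

# Mask the function now builds internally when image_mask is None:
default = source_image.new_ones(1, 1, *source_image.shape[-2:])

# Same per-pixel coverage; the (1, 1, H, W) mask broadcasts over the batch.
assert torch.equal(explicit[0], default[0])
```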