@@ -134,11 +134,7 @@ def txt2img_image_conditioning(self, x, width=None, height=None):
134
134
# Dummy zero conditioning if we're not using inpainting model.
135
135
# Still takes up a bit of memory, but no encoder call.
136
136
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
137
- return torch .zeros (
138
- x .shape [0 ], 5 , 1 , 1 ,
139
- dtype = x .dtype ,
140
- device = x .device
141
- )
137
+ return x .new_zeros (x .shape [0 ], 5 , 1 , 1 )
142
138
143
139
height = height or self .height
144
140
width = width or self .width
@@ -156,11 +152,7 @@ def txt2img_image_conditioning(self, x, width=None, height=None):
156
152
def img2img_image_conditioning (self , source_image , latent_image , image_mask = None ):
157
153
if self .sampler .conditioning_key not in {'hybrid' , 'concat' }:
158
154
# Dummy zero conditioning if we're not using inpainting model.
159
- return torch .zeros (
160
- latent_image .shape [0 ], 5 , 1 , 1 ,
161
- dtype = latent_image .dtype ,
162
- device = latent_image .device
163
- )
155
+ return latent_image .new_zeros (latent_image .shape [0 ], 5 , 1 , 1 )
164
156
165
157
# Handle the different mask inputs
166
158
if image_mask is not None :
@@ -174,11 +166,11 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask = No
174
166
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
175
167
conditioning_mask = torch .round (conditioning_mask )
176
168
else :
177
- conditioning_mask = torch . ones (1 , 1 , * source_image .shape [- 2 :])
169
+ conditioning_mask = source_image . new_ones (1 , 1 , * source_image .shape [- 2 :])
178
170
179
171
# Create another latent image, this time with a masked version of the original input.
180
172
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
181
- conditioning_mask = conditioning_mask .to (source_image .device )
173
+ conditioning_mask = conditioning_mask .to (source_image .device ). to ( source_image . dtype )
182
174
conditioning_image = torch .lerp (
183
175
source_image ,
184
176
source_image * (1.0 - conditioning_mask ),
@@ -674,6 +666,13 @@ def save_intermediate(image, index):
674
666
675
667
if opts .use_scale_latent_for_hires_fix :
676
668
samples = torch .nn .functional .interpolate (samples , size = (self .height // opt_f , self .width // opt_f ), mode = "bilinear" )
669
+
670
+ # Avoid making the inpainting conditioning unless necessary as
671
+ # this does need some extra compute to decode / encode the image again.
672
+ if getattr (self , "inpainting_mask_weight" , shared .opts .inpainting_mask_weight ) < 1.0 :
673
+ image_conditioning = self .img2img_image_conditioning (decode_first_stage (self .sd_model , samples ), samples )
674
+ else :
675
+ image_conditioning = self .txt2img_image_conditioning (samples )
677
676
678
677
for i in range (samples .shape [0 ]):
679
678
save_intermediate (samples , i )
@@ -700,14 +699,14 @@ def save_intermediate(image, index):
700
699
701
700
samples = self .sd_model .get_first_stage_encoding (self .sd_model .encode_first_stage (decoded_samples ))
702
701
702
+ image_conditioning = self .img2img_image_conditioning (decoded_samples , samples )
703
+
703
704
shared .state .nextjob ()
704
705
705
706
self .sampler = sd_samplers .create_sampler_with_index (sd_samplers .samplers , self .sampler_index , self .sd_model )
706
707
707
708
noise = create_random_tensors (samples .shape [1 :], seeds = seeds , subseeds = subseeds , subseed_strength = subseed_strength , seed_resize_from_h = self .seed_resize_from_h , seed_resize_from_w = self .seed_resize_from_w , p = self )
708
709
709
- image_conditioning = self .txt2img_image_conditioning (x )
710
-
711
710
# GC now before running the next img2img to prevent running out of memory
712
711
x = None
713
712
devices .torch_gc ()
0 commit comments