Skip to content

Commit 6e2ce4e

Browse files
Added image conditioning to latent upscale.
Only computed if the mask weight is not 1.0 to avoid extra memory. Also includes some code cleanup.
1 parent 44ab954 commit 6e2ce4e

File tree

1 file changed

+11
-18
lines changed

1 file changed

+11
-18
lines changed

modules/processing.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,7 @@ def txt2img_image_conditioning(self, x, width=None, height=None):
134134
# Dummy zero conditioning if we're not using inpainting model.
135135
# Still takes up a bit of memory, but no encoder call.
136136
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
137-
return torch.zeros(
138-
x.shape[0], 5, 1, 1,
139-
dtype=x.dtype,
140-
device=x.device
141-
)
137+
return x.new_zeros(x.shape[0], 5, 1, 1)
142138

143139
height = height or self.height
144140
width = width or self.width
@@ -156,11 +152,7 @@ def txt2img_image_conditioning(self, x, width=None, height=None):
156152
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
157153
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
158154
# Dummy zero conditioning if we're not using inpainting model.
159-
return torch.zeros(
160-
latent_image.shape[0], 5, 1, 1,
161-
dtype=latent_image.dtype,
162-
device=latent_image.device
163-
)
155+
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
164156

165157
# Handle the different mask inputs
166158
if image_mask is not None:
@@ -174,11 +166,10 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask = No
174166
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
175167
conditioning_mask = torch.round(conditioning_mask)
176168
else:
177-
conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
169+
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
178170

179171
# Create another latent image, this time with a masked version of the original input.
180172
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
181-
conditioning_mask = conditioning_mask.to(source_image.device)
182173
conditioning_image = torch.lerp(
183174
source_image,
184175
source_image * (1.0 - conditioning_mask),
@@ -653,7 +644,13 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
653644

654645
if opts.use_scale_latent_for_hires_fix:
655646
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
656-
image_conditioning = self.txt2img_image_conditioning(samples)
647+
648+
# Avoid making the inpainting conditioning unless necessary as
649+
# this does need some extra compute to decode / encode the image again.
650+
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
651+
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
652+
else:
653+
image_conditioning = self.txt2img_image_conditioning(samples)
657654

658655
else:
659656
decoded_samples = decode_first_stage(self.sd_model, samples)
@@ -675,11 +672,7 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
675672

676673
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
677674

678-
image_conditioning = self.img2img_image_conditioning(
679-
decoded_samples,
680-
samples,
681-
decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
682-
)
675+
image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
683676

684677
shared.state.nextjob()
685678

0 commit comments

Comments
 (0)