Skip to content

Commit 2cf3d2a

Browse files
Merge pull request #3923 from random-thoughtss/master
Fix weighted mask for highres fix
2 parents 3f0f328 + 243253f commit 2cf3d2a

File tree

2 files changed

+14
-15
lines changed

2 files changed

+14
-15
lines changed

modules/masking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def expand_crop_region(crop_region, processing_width, processing_height, image_w
 49  49          ratio_processing = processing_width / processing_height
 50  50
 51  51          if ratio_crop_region > ratio_processing:
 52     -            desired_height = (x2 - x1) * ratio_processing
     52 +            desired_height = (x2 - x1) / ratio_processing
 53  53          desired_height_diff = int(desired_height - (y2-y1))
 54  54          y1 -= desired_height_diff//2
 55  55          y2 += desired_height_diff - desired_height_diff//2

modules/processing.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,7 @@ def txt2img_image_conditioning(self, x, width=None, height=None):
 134 134          # Dummy zero conditioning if we're not using inpainting model.
 135 135          # Still takes up a bit of memory, but no encoder call.
 136 136          # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
 137     -        return torch.zeros(
 138     -            x.shape[0], 5, 1, 1,
 139     -            dtype=x.dtype,
 140     -            device=x.device
 141     -        )
     137 +        return x.new_zeros(x.shape[0], 5, 1, 1)
 142 138
 143 139      height = height or self.height
 144 140      width = width or self.width
@@ -156,11 +152,7 @@ def txt2img_image_conditioning(self, x, width=None, height=None):
 156 152      def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
 157 153          if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
 158 154              # Dummy zero conditioning if we're not using inpainting model.
 159     -            return torch.zeros(
 160     -                latent_image.shape[0], 5, 1, 1,
 161     -                dtype=latent_image.dtype,
 162     -                device=latent_image.device
 163     -            )
     155 +            return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
 164 156
 165 157          # Handle the different mask inputs
 166 158          if image_mask is not None:
@@ -174,11 +166,11 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask = No
 174 166              # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
 175 167              conditioning_mask = torch.round(conditioning_mask)
 176 168          else:
 177     -            conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
     169 +            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
 178 170
 179 171          # Create another latent image, this time with a masked version of the original input.
 180 172          # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
 181     -        conditioning_mask = conditioning_mask.to(source_image.device)
     173 +        conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
 182 174          conditioning_image = torch.lerp(
 183 175              source_image,
 184 176              source_image * (1.0 - conditioning_mask),
@@ -674,6 +666,13 @@ def save_intermediate(image, index):
 674 666
 675 667          if opts.use_scale_latent_for_hires_fix:
 676 668              samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
     669 +
     670 +            # Avoid making the inpainting conditioning unless necessary as
     671 +            # this does need some extra compute to decode / encode the image again.
     672 +            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
     673 +                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
     674 +            else:
     675 +                image_conditioning = self.txt2img_image_conditioning(samples)
 677 676
 678 677          for i in range(samples.shape[0]):
 679 678              save_intermediate(samples, i)
@@ -700,14 +699,14 @@ def save_intermediate(image, index):
 700 699
 701 700              samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
 702 701
     702 +            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
     703 +
 703 704          shared.state.nextjob()
 704 705
 705 706          self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
 706 707
 707 708          noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
 708 709
 709     -        image_conditioning = self.txt2img_image_conditioning(x)
 710     -
 711 710          # GC now before running the next img2img to prevent running out of memory
 712 711          x = None
 713 712          devices.torch_gc()

0 commit comments

Comments (0)