Skip to content

Commit ea5b90b

Browse files
Merge pull request #4371 from hotdogee/master
Fixes #800 #1562 #2075 #2304 #2931 LDSR upscaler producing black bars
2 parents 2f47724 + 6603f63 commit ea5b90b

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

modules/ldsr_model_arch.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,21 @@ def super_resolution(self, image, steps=100, target_scale=2, half_attention=Fals
101101
down_sample_rate = target_scale / 4
102102
wd = width_og * down_sample_rate
103103
hd = height_og * down_sample_rate
104-
width_downsampled_pre = int(wd)
105-
height_downsampled_pre = int(hd)
104+
width_downsampled_pre = int(np.ceil(wd))
105+
height_downsampled_pre = int(np.ceil(hd))
106106

107107
if down_sample_rate != 1:
108108
print(
109109
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
110110
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
111111
else:
112112
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
113-
logs = self.run(model["model"], im_og, diffusion_steps, eta)
113+
114+
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
115+
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
116+
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
117+
118+
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
114119

115120
sample = logs["sample"]
116121
sample = sample.detach().cpu()
@@ -120,6 +125,9 @@ def super_resolution(self, image, steps=100, target_scale=2, half_attention=Fals
120125
sample = np.transpose(sample, (0, 2, 3, 1))
121126
a = Image.fromarray(sample[0])
122127

128+
# remove padding
129+
a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
130+
123131
del model
124132
gc.collect()
125133
torch.cuda.empty_cache()

0 commit comments

Comments
 (0)