Skip to content

Commit e359268

Browse files
Merge pull request #3976 from victorca25/esrgan_fea
multiple trivial changes for "extras" models
2 parents bb21a4c + c9bb33d commit e359268

File tree

4 files changed

+33
-6
lines changed

4 files changed

+33
-6
lines changed

modules/esrgan_model.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def mod2normal(state_dict):
5050
def resrgan2normal(state_dict, nb=23):
5151
# this code is copied from https://github.com/victorca25/iNNfer
5252
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
53+
re8x = 0
5354
crt_net = {}
5455
items = []
5556
for k, v in state_dict.items():
@@ -75,10 +76,18 @@ def resrgan2normal(state_dict, nb=23):
7576
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
7677
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
7778
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
78-
crt_net['model.8.weight'] = state_dict['conv_hr.weight']
79-
crt_net['model.8.bias'] = state_dict['conv_hr.bias']
80-
crt_net['model.10.weight'] = state_dict['conv_last.weight']
81-
crt_net['model.10.bias'] = state_dict['conv_last.bias']
79+
80+
if 'conv_up3.weight' in state_dict:
81+
# modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
82+
re8x = 3
83+
crt_net['model.9.weight'] = state_dict['conv_up3.weight']
84+
crt_net['model.9.bias'] = state_dict['conv_up3.bias']
85+
86+
crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
87+
crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
88+
crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
89+
crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
90+
8291
state_dict = crt_net
8392
return state_dict
8493

modules/modelloader.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ def cleanup_models():
8585
src_path = os.path.join(root_path, "ESRGAN")
8686
dest_path = os.path.join(models_path, "ESRGAN")
8787
move_files(src_path, dest_path)
88+
src_path = os.path.join(models_path, "BSRGAN")
89+
dest_path = os.path.join(models_path, "ESRGAN")
90+
move_files(src_path, dest_path, ".pth")
8891
src_path = os.path.join(root_path, "gfpgan")
8992
dest_path = os.path.join(models_path, "GFPGAN")
9093
move_files(src_path, dest_path)

modules/ui.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1054,7 +1054,7 @@ def create_ui(wrap_gradio_gpu_call):
10541054

10551055
with gr.Tabs(elem_id="extras_resize_mode"):
10561056
with gr.TabItem('Scale by'):
1057-
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
1057+
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4)
10581058
with gr.TabItem('Scale to'):
10591059
with gr.Group():
10601060
with gr.Row():

modules/upscaler.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from modules import modelloader, shared
1111

1212
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
13+
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
1314
from modules.paths import models_path
1415

1516

@@ -57,7 +58,7 @@ def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
5758
dest_w = img.width * scale
5859
dest_h = img.height * scale
5960
for i in range(3):
60-
if img.width >= dest_w and img.height >= dest_h:
61+
if img.width > dest_w and img.height > dest_h:
6162
break
6263
img = self.do_upscale(img, selected_model)
6364
if img.width != dest_w or img.height != dest_h:
@@ -120,3 +121,17 @@ def __init__(self, dirname=None):
120121
self.name = "Lanczos"
121122
self.scalers = [UpscalerData("Lanczos", None, self)]
122123

124+
125+
class UpscalerNearest(Upscaler):
126+
scalers = []
127+
128+
def do_upscale(self, img, selected_model=None):
129+
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
130+
131+
def load_model(self, _):
132+
pass
133+
134+
def __init__(self, dirname=None):
135+
super().__init__(False)
136+
self.name = "Nearest"
137+
self.scalers = [UpscalerData("Nearest", None, self)]

0 commit comments

Comments
 (0)