Skip to content

Commit 056f06d

Browse files
committed
Reload VAE without reloading sd checkpoint
1 parent f8c6468 commit 056f06d

File tree

3 files changed

+98
-18
lines changed

3 files changed

+98
-18
lines changed

modules/sd_models.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,15 +159,13 @@ def get_state_dict_from_checkpoint(pl_sd):
159159
return pl_sd
160160

161161

162-
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
163-
164162
def load_model_weights(model, checkpoint_info, vae_file="auto"):
165163
checkpoint_file = checkpoint_info.filename
166164
sd_model_hash = checkpoint_info.hash
167165

168166
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
169167

170-
checkpoint_key = (checkpoint_info, vae_file)
168+
checkpoint_key = checkpoint_info
171169

172170
if checkpoint_key not in checkpoints_loaded:
173171
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
@@ -190,13 +188,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
190188
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
191189
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
192190

193-
sd_vae.load_vae(model, vae_file)
194-
model.first_stage_model.to(devices.dtype_vae)
195-
196191
if shared.opts.sd_checkpoint_cache > 0:
192+
# if PR #4035 were to get merged, restore base VAE first before caching
197193
checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
198194
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
199195
checkpoints_loaded.popitem(last=False) # LRU
196+
200197
else:
201198
vae_name = sd_vae.get_filename(vae_file)
202199
print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache")
@@ -207,6 +204,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
207204
model.sd_model_checkpoint = checkpoint_file
208205
model.sd_checkpoint_info = checkpoint_info
209206

207+
sd_vae.load_vae(model, vae_file)
208+
210209

211210
def load_model(checkpoint_info=None):
212211
from modules import lowvram, sd_hijack
@@ -254,14 +253,14 @@ def load_model(checkpoint_info=None):
254253
return sd_model
255254

256255

257-
def reload_model_weights(sd_model=None, info=None, force=False):
256+
def reload_model_weights(sd_model=None, info=None):
258257
from modules import lowvram, devices, sd_hijack
259258
checkpoint_info = info or select_checkpoint()
260259

261260
if not sd_model:
262261
sd_model = shared.sd_model
263262

264-
if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force:
263+
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
265264
return
266265

267266
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):

modules/sd_vae.py

Lines changed: 90 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,65 @@
11
import torch
22
import os
33
from collections import namedtuple
4-
from modules import shared, devices
4+
from modules import shared, devices, script_callbacks
55
from modules.paths import models_path
66
import glob
77

8+
89
model_dir = "Stable-diffusion"
910
model_path = os.path.abspath(os.path.join(models_path, model_dir))
1011
vae_dir = "VAE"
1112
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
1213

14+
1315
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
16+
17+
1418
default_vae_dict = {"auto": "auto", "None": "None"}
1519
default_vae_list = ["auto", "None"]
20+
21+
1622
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
1723
vae_dict = dict(default_vae_dict)
1824
vae_list = list(default_vae_list)
1925
first_load = True
2026

27+
28+
base_vae = None
29+
loaded_vae_file = None
30+
checkpoint_info = None
31+
32+
33+
def get_base_vae(model):
    """Return the cached base (checkpoint-bundled) VAE state dict for *model*.

    Returns None when nothing is cached, or when the cached snapshot belongs
    to a different checkpoint than the one *model* currently holds.
    """
    # Bug fix: check `model` FIRST. The original evaluated
    # `model.sd_checkpoint_info` before the `and model` truthiness guard,
    # so a None model raised AttributeError instead of falling through.
    if model and base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
        return base_vae
    return None
37+
38+
39+
def store_base_vae(model):
    """Snapshot the checkpoint's own VAE weights before they are replaced.

    Saves a deep copy of ``model.first_stage_model``'s state dict so
    restore_base_vae() can put it back later. A deep copy is required:
    ``state_dict().copy()`` only copies the dict, and a subsequent
    ``load_state_dict()`` copies new values into the *same* parameter
    tensors the shallow snapshot references, silently destroying it.
    """
    import copy  # local import so this fix is self-contained

    global base_vae, checkpoint_info
    # Only snapshot once per checkpoint; a matching checkpoint_info means
    # base_vae already holds this checkpoint's original VAE.
    if checkpoint_info != model.sd_checkpoint_info:
        base_vae = copy.deepcopy(model.first_stage_model.state_dict())
        checkpoint_info = model.sd_checkpoint_info
44+
45+
46+
def delete_base_vae():
    """Forget the cached base VAE snapshot and which checkpoint it came from."""
    global base_vae, checkpoint_info
    checkpoint_info = None
    base_vae = None
50+
51+
52+
def restore_base_vae(model):
    """Reinstate the snapshotted base VAE into *model*, then drop the cache.

    The snapshot is applied only when it belongs to the checkpoint the model
    currently holds; the cache is cleared unconditionally afterwards.
    """
    global base_vae, checkpoint_info
    snapshot_matches = (
        base_vae is not None and checkpoint_info == model.sd_checkpoint_info
    )
    if snapshot_matches:
        load_vae_dict(model, base_vae)
    delete_base_vae()
57+
58+
2159
def get_filename(filepath):
    """Return the base name of *filepath* with its final extension removed."""
    base = os.path.basename(filepath)
    name, _ext = os.path.splitext(base)
    return name
2361

62+
2463
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
2564
global vae_dict, vae_list
2665
res = {}
@@ -43,6 +82,7 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
4382
vae_dict.update(res)
4483
return vae_list
4584

85+
4686
def resolve_vae(checkpoint_file, vae_file="auto"):
4787
global first_load, vae_dict, vae_list
4888
# save_settings = False
@@ -96,24 +136,26 @@ def resolve_vae(checkpoint_file, vae_file="auto"):
96136

97137
return vae_file
98138

99-
def load_vae(model, vae_file):
100-
global first_load, vae_dict, vae_list
139+
140+
def load_vae(model, vae_file=None):
141+
global first_load, vae_dict, vae_list, loaded_vae_file
101142
# save_settings = False
102143

103144
if vae_file:
104145
print(f"Loading VAE weights from: {vae_file}")
105146
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
106147
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
107-
model.first_stage_model.load_state_dict(vae_dict_1)
148+
load_vae_dict(model, vae_dict_1)
108149

109-
# If vae used is not in dict, update it
110-
# It will be removed on refresh though
111-
if vae_file is not None:
150+
# If vae used is not in dict, update it
151+
# It will be removed on refresh though
112152
vae_opt = get_filename(vae_file)
113153
if vae_opt not in vae_dict:
114154
vae_dict[vae_opt] = vae_file
115155
vae_list.append(vae_opt)
116156

157+
loaded_vae_file = vae_file
158+
117159
"""
118160
# Save current VAE to VAE settings, maybe? will it work?
119161
if save_settings:
@@ -124,4 +166,45 @@ def load_vae(model, vae_file):
124166
"""
125167

126168
first_load = False
169+
170+
171+
# Internal helper -- don't call this from outside this module.
def load_vae_dict(model, vae_dict_1=None):
    """Swap the model's first-stage (VAE) weights in place.

    With a state dict: snapshot the current base VAE, then load the new
    weights. Without one: restore the snapshotted base VAE instead.
    In both cases the VAE is then moved to the configured VAE dtype.
    """
    if vae_dict_1:
        store_base_vae(model)
        model.first_stage_model.load_state_dict(vae_dict_1)
    else:
        # Bug fix: restore_base_vae takes the model as an argument; the
        # original called restore_base_vae() with no arguments, which
        # raised TypeError whenever the restore path was taken.
        restore_base_vae(model)
    model.first_stage_model.to(devices.dtype_vae)
179+
180+
181+
def reload_vae_weights(sd_model=None, vae_file="auto"):
    """Load a (possibly different) VAE into the model without reloading
    the whole SD checkpoint.

    sd_model: model to update; defaults to the globally shared model.
    vae_file: VAE path or "auto" to resolve one for the current checkpoint.
    Returns the updated model, or None when the requested VAE is already
    loaded (no work done).
    """
    from modules import lowvram, devices, sd_hijack

    if not sd_model:
        sd_model = shared.sd_model

    checkpoint_info = sd_model.sd_checkpoint_info
    checkpoint_file = checkpoint_info.filename
    vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)

    # Nothing to do if the resolved VAE is the one already in the model.
    if loaded_vae_file == vae_file:
        return

    # Move weights off the GPU while swapping, mirroring full model reload.
    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
        sd_model.to(devices.cpu)

    # Undo the hijack so first_stage_model is the plain module, swap the
    # VAE, then re-apply the hijack and notify callbacks as a model load.
    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_vae(sd_model, vae_file)

    sd_hijack.model_hijack.hijack(sd_model)
    script_callbacks.model_loaded_callback(sd_model)

    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)

    # Plain string: the original used an f-string with no placeholders.
    print("VAE Weights loaded.")
    return sd_model

webui.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,7 @@ def initialize():
8181
modules.sd_vae.refresh_vae_list()
8282
modules.sd_models.load_model()
8383
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
84-
# I don't know what needs to be done to only reload VAE, with all those hijacks callbacks, and lowvram,
85-
# so for now this reloads the whole model too
86-
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(force=True)), call=False)
84+
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
8785
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
8886
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
8987

0 commit comments

Comments
 (0)