Commit 675b51e

Merge pull request #3986 from R-N/vae-picker

VAE Selector

2 parents e359268 + a5409a6

6 files changed: +233 lines, -28 lines

models/VAE/Put VAE here.txt

Whitespace-only changes.

modules/sd_models.py

Lines changed: 17 additions & 24 deletions
@@ -9,7 +9,7 @@
 
 from ldm.util import instantiate_from_config
 
-from modules import shared, modelloader, devices, script_callbacks
+from modules import shared, modelloader, devices, script_callbacks, sd_vae
 from modules.paths import models_path
 from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
 
@@ -159,14 +159,15 @@ def get_state_dict_from_checkpoint(pl_sd):
     return pl_sd
 
 
-vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
-
-
-def load_model_weights(model, checkpoint_info):
+def load_model_weights(model, checkpoint_info, vae_file="auto"):
     checkpoint_file = checkpoint_info.filename
     sd_model_hash = checkpoint_info.hash
 
-    if checkpoint_info not in checkpoints_loaded:
+    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+
+    checkpoint_key = checkpoint_info
+
+    if checkpoint_key not in checkpoints_loaded:
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
         pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
@@ -187,32 +188,24 @@ def load_model_weights(model, checkpoint_info):
         devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
         devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
 
-        vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
-
-        if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
-            vae_file = shared.cmd_opts.vae_path
-
-        if os.path.exists(vae_file):
-            print(f"Loading VAE weights from: {vae_file}")
-            vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
-            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
-            model.first_stage_model.load_state_dict(vae_dict)
-
-        model.first_stage_model.to(devices.dtype_vae)
-
         if shared.opts.sd_checkpoint_cache > 0:
-            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+            # if PR #4035 were to get merged, restore base VAE first before caching
+            checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
             while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
                 checkpoints_loaded.popitem(last=False) # LRU
+
     else:
-        print(f"Loading weights [{sd_model_hash}] from cache")
-        checkpoints_loaded.move_to_end(checkpoint_info)
-        model.load_state_dict(checkpoints_loaded[checkpoint_info])
+        vae_name = sd_vae.get_filename(vae_file)
+        print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache")
+        checkpoints_loaded.move_to_end(checkpoint_key)
+        model.load_state_dict(checkpoints_loaded[checkpoint_key])
 
     model.sd_model_hash = sd_model_hash
     model.sd_model_checkpoint = checkpoint_file
     model.sd_checkpoint_info = checkpoint_info
 
+    sd_vae.load_vae(model, vae_file)
+
 
 def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
@@ -263,7 +256,7 @@ def load_model(checkpoint_info=None):
 def reload_model_weights(sd_model=None, info=None):
     from modules import lowvram, devices, sd_hijack
     checkpoint_info = info or select_checkpoint()
-
+
     if not sd_model:
         sd_model = shared.sd_model
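The net effect of the changes above: load_model_weights() no longer looks for a ".vae.pt" next to the checkpoint itself. It asks the new sd_vae module which VAE file to use (sd_vae.resolve_vae) and applies it after the checkpoint weights are in place (sd_vae.load_vae), and the cache-hit message now reports which VAE is in use. A minimal sketch of the resulting call flow, using only names that appear in these diffs; the surrounding driver script is illustrative, not part of the commit:

from modules import sd_models, sd_vae

sd_vae.refresh_vae_list()                        # populate vae_list / vae_dict (see sd_vae.py below)
checkpoint_info = sd_models.select_checkpoint()  # the checkpoint chosen in settings
model = sd_models.load_model(checkpoint_info)    # calls load_model_weights(), which now resolves and loads the VAE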

modules/sd_vae.py

Lines changed: 207 additions & 0 deletions
@@ -0,0 +1,207 @@
import torch
import os
from collections import namedtuple
from modules import shared, devices, script_callbacks
from modules.paths import models_path
import glob


model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
vae_dir = "VAE"
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))


vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}


default_vae_dict = {"auto": "auto", "None": "None"}
default_vae_list = ["auto", "None"]


default_vae_values = [default_vae_dict[x] for x in default_vae_list]
vae_dict = dict(default_vae_dict)
vae_list = list(default_vae_list)
first_load = True


base_vae = None
loaded_vae_file = None
checkpoint_info = None


def get_base_vae(model):
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
        return base_vae
    return None


def store_base_vae(model):
    global base_vae, checkpoint_info
    if checkpoint_info != model.sd_checkpoint_info:
        base_vae = model.first_stage_model.state_dict().copy()
        checkpoint_info = model.sd_checkpoint_info


def delete_base_vae():
    global base_vae, checkpoint_info
    base_vae = None
    checkpoint_info = None


def restore_base_vae(model):
    global base_vae, checkpoint_info
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
        load_vae_dict(model, base_vae)
    delete_base_vae()


def get_filename(filepath):
    return os.path.splitext(os.path.basename(filepath))[0]


def refresh_vae_list(vae_path=vae_path, model_path=model_path):
    global vae_dict, vae_list
    res = {}
    candidates = [
        *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
        *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
        *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
        *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
    ]
    if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
        candidates.append(shared.cmd_opts.vae_path)
    for filepath in candidates:
        name = get_filename(filepath)
        res[name] = filepath
    vae_list.clear()
    vae_list.extend(default_vae_list)
    vae_list.extend(list(res.keys()))
    vae_dict.clear()
    vae_dict.update(res)
    vae_dict.update(default_vae_dict)
    return vae_list


def resolve_vae(checkpoint_file, vae_file="auto"):
    global first_load, vae_dict, vae_list

    # if vae_file argument is provided, it takes priority, but not saved
    if vae_file and vae_file not in default_vae_list:
        if not os.path.isfile(vae_file):
            vae_file = "auto"
            print("VAE provided as function argument doesn't exist")
    # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
    if first_load and shared.cmd_opts.vae_path is not None:
        if os.path.isfile(shared.cmd_opts.vae_path):
            vae_file = shared.cmd_opts.vae_path
            shared.opts.data['sd_vae'] = get_filename(vae_file)
        else:
            print("VAE provided as command line argument doesn't exist")
    # else, we load from settings
    if vae_file == "auto" and shared.opts.sd_vae is not None:
        # if saved VAE settings isn't recognized, fallback to auto
        vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
        # if VAE selected but not found, fallback to auto
        if vae_file not in default_vae_values and not os.path.isfile(vae_file):
            vae_file = "auto"
            print("Selected VAE doesn't exist")
    # vae-path cmd arg takes priority for auto
    if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
        if os.path.isfile(shared.cmd_opts.vae_path):
            vae_file = shared.cmd_opts.vae_path
            print("Using VAE provided as command line argument")
    # if still not found, try look for ".vae.pt" beside model
    model_path = os.path.splitext(checkpoint_file)[0]
    if vae_file == "auto":
        vae_file_try = model_path + ".vae.pt"
        if os.path.isfile(vae_file_try):
            vae_file = vae_file_try
            print("Using VAE found beside selected model")
    # if still not found, try look for ".vae.ckpt" beside model
    if vae_file == "auto":
        vae_file_try = model_path + ".vae.ckpt"
        if os.path.isfile(vae_file_try):
            vae_file = vae_file_try
            print("Using VAE found beside selected model")
    # No more fallbacks for auto
    if vae_file == "auto":
        vae_file = None
    # Last check, just because
    if vae_file and not os.path.exists(vae_file):
        vae_file = None

    return vae_file


def load_vae(model, vae_file=None):
    global first_load, vae_dict, vae_list, loaded_vae_file
    # save_settings = False

    if vae_file:
        print(f"Loading VAE weights from: {vae_file}")
        vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
        vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
        load_vae_dict(model, vae_dict_1)

        # If vae used is not in dict, update it
        # It will be removed on refresh though
        vae_opt = get_filename(vae_file)
        if vae_opt not in vae_dict:
            vae_dict[vae_opt] = vae_file
            vae_list.append(vae_opt)

    loaded_vae_file = vae_file

    """
    # Save current VAE to VAE settings, maybe? will it work?
    if save_settings:
        if vae_file is None:
            vae_opt = "None"

        # shared.opts.sd_vae = vae_opt
    """

    first_load = False


# don't call this from outside
def load_vae_dict(model, vae_dict_1=None):
    if vae_dict_1:
        store_base_vae(model)
        model.first_stage_model.load_state_dict(vae_dict_1)
    else:
        restore_base_vae()
    model.first_stage_model.to(devices.dtype_vae)


def reload_vae_weights(sd_model=None, vae_file="auto"):
    from modules import lowvram, devices, sd_hijack

    if not sd_model:
        sd_model = shared.sd_model

    checkpoint_info = sd_model.sd_checkpoint_info
    checkpoint_file = checkpoint_info.filename
    vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)

    if loaded_vae_file == vae_file:
        return

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
        sd_model.to(devices.cpu)

    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_vae(sd_model, vae_file)

    sd_hijack.model_hijack.hijack(sd_model)
    script_callbacks.model_loaded_callback(sd_model)

    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)

    print(f"VAE Weights loaded.")
    return sd_model
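resolve_vae() above encodes the selection priority behind the new dropdown: an explicit vae_file argument, then (on first load only) the --vae-path command-line argument, then the saved "sd_vae" setting, then --vae-path again as the "auto" fallback, then a ".vae.pt" or ".vae.ckpt" file sitting beside the checkpoint, and finally None. Note that load_vae_dict() calls restore_base_vae() without the model argument its definition takes. Below is a short sketch of driving the module directly; the checkpoint path and VAE name are hypothetical:

from modules import shared, sd_vae

sd_vae.refresh_vae_list()      # scans models/VAE plus *.vae.pt / *.vae.ckpt beside checkpoints
print(sd_vae.vae_list)         # e.g. ["auto", "None", "my-vae"] if a my-vae.pt exists

vae_file = sd_vae.resolve_vae("models/Stable-diffusion/my-model.ckpt", vae_file="auto")
sd_vae.load_vae(shared.sd_model, vae_file)   # swaps the model's first_stage_model weights in place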

modules/shared.py

Lines changed: 5 additions & 3 deletions
@@ -15,7 +15,7 @@
 import modules.sd_models
 import modules.styles
 import modules.devices as devices
-from modules import sd_samplers, sd_models, localization
+from modules import sd_samplers, sd_models, localization, sd_vae
 from modules.hypernetworks import hypernetwork
 from modules.paths import models_path, script_path, sd_path
 
@@ -319,6 +319,7 @@ def options_section(section_identifier, options_dict):
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
     "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+    "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
     "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
     "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
     "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
@@ -437,11 +438,12 @@ def load(self, filename):
         if bad_settings > 0:
             print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
 
-    def onchange(self, key, func):
+    def onchange(self, key, func, call=True):
         item = self.data_labels.get(key)
         item.onchange = func
 
-        func()
+        if call:
+            func()
 
     def dumpjson(self):
         d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
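The onchange() change is what lets webui.py register the VAE callback without firing it immediately: call defaults to True, so existing callers keep the old behaviour of invoking the handler right away, while call=False only stores it. A sketch with placeholder handler functions (not from the commit):

def on_checkpoint_change():
    ...  # placeholder handler

def on_vae_change():
    ...  # placeholder handler

shared.opts.onchange("sd_model_checkpoint", on_checkpoint_change)  # registers and runs the handler now
shared.opts.onchange("sd_vae", on_vae_change, call=False)          # registers only; runs when the setting changes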

style.css

Lines changed: 1 addition & 1 deletion
@@ -501,7 +501,7 @@ input[type="range"]{
     padding: 0;
 }
 
-#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
+#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
     max-width: 2.5em;
     min-width: 2.5em;
     height: 2.4em;

webui.py

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,7 @@
 import modules.scripts
 import modules.sd_hijack
 import modules.sd_models
+import modules.sd_vae
 import modules.shared as shared
 import modules.txt2img
 import modules.script_callbacks
@@ -77,8 +78,10 @@ def initialize():
 
     modules.scripts.load_scripts()
 
+    modules.sd_vae.refresh_vae_list()
     modules.sd_models.load_model()
     shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
+    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
     shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
     shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
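In initialize(), refresh_vae_list() runs before load_model() so the VAE choices exist by the time load_model_weights() resolves one, and the "sd_vae" handler is registered with call=False, presumably because the selected VAE has already been applied during load_model() and reloading it at startup would be redundant. Conceptually, a later change to the setting flows like this (the VAE name is a placeholder):

shared.opts.data["sd_vae"] = "my-vae"                 # the Settings UI stores the new dropdown value
modules.sd_vae.reload_vae_weights(shared.sd_model)    # the registered callback then runs this, which
                                                      # resolves "my-vae" via vae_dict and loads it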
