@@ -1,26 +1,65 @@
 import torch
 import os
 from collections import namedtuple
-from modules import shared, devices
+from modules import shared, devices, script_callbacks
 from modules.paths import models_path
 import glob

+
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
 vae_dir = "VAE"
 vae_path = os.path.abspath(os.path.join(models_path, vae_dir))

+
 vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
+
+
 default_vae_dict = {"auto": "auto", "None": "None"}
 default_vae_list = ["auto", "None"]
+
+
 default_vae_values = [default_vae_dict[x] for x in default_vae_list]
 vae_dict = dict(default_vae_dict)
 vae_list = list(default_vae_list)
 first_load = True

+
+base_vae = None
+loaded_vae_file = None
+checkpoint_info = None
+
+
+def get_base_vae(model):
+    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
+        return base_vae
+    return None
+
+
+def store_base_vae(model):
+    global base_vae, checkpoint_info
+    if checkpoint_info != model.sd_checkpoint_info:
+        base_vae = model.first_stage_model.state_dict().copy()
+        checkpoint_info = model.sd_checkpoint_info
+
+
+def delete_base_vae():
+    global base_vae, checkpoint_info
+    base_vae = None
+    checkpoint_info = None
+
+
+def restore_base_vae(model):
+    global base_vae, checkpoint_info
+    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
+        load_vae_dict(model, base_vae)
+    delete_base_vae()
+
+
 def get_filename(filepath):
     return os.path.splitext(os.path.basename(filepath))[0]

+
 def refresh_vae_list(vae_path=vae_path, model_path=model_path):
     global vae_dict, vae_list
     res = {}
@@ -43,6 +82,7 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
     vae_dict.update(res)
     return vae_list

+
 def resolve_vae(checkpoint_file, vae_file="auto"):
     global first_load, vae_dict, vae_list
     # save_settings = False
@@ -96,24 +136,26 @@ def resolve_vae(checkpoint_file, vae_file="auto"):

     return vae_file

-def load_vae(model, vae_file):
-    global first_load, vae_dict, vae_list
+
+def load_vae(model, vae_file=None):
+    global first_load, vae_dict, vae_list, loaded_vae_file
     # save_settings = False

     if vae_file:
         print(f"Loading VAE weights from: {vae_file}")
         vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
         vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
-        model.first_stage_model.load_state_dict(vae_dict_1)
+        load_vae_dict(model, vae_dict_1)

-    # If vae used is not in dict, update it
-    # It will be removed on refresh though
-    if vae_file is not None:
+        # If vae used is not in dict, update it
+        # It will be removed on refresh though
         vae_opt = get_filename(vae_file)
         if vae_opt not in vae_dict:
             vae_dict[vae_opt] = vae_file
             vae_list.append(vae_opt)

+    loaded_vae_file = vae_file
+
     """
     # Save current VAE to VAE settings, maybe? will it work?
     if save_settings:
@@ -124,4 +166,45 @@ def load_vae(model, vae_file):
     """

     first_load = False
+
+
+# don't call this from outside
+def load_vae_dict(model, vae_dict_1=None):
+    if vae_dict_1:
+        store_base_vae(model)
+        model.first_stage_model.load_state_dict(vae_dict_1)
+    else:
+        restore_base_vae(model)
     model.first_stage_model.to(devices.dtype_vae)
+
+
+def reload_vae_weights(sd_model=None, vae_file="auto"):
+    from modules import lowvram, devices, sd_hijack
+
+    if not sd_model:
+        sd_model = shared.sd_model
+
+    checkpoint_info = sd_model.sd_checkpoint_info
+    checkpoint_file = checkpoint_info.filename
+    vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
+
+    if loaded_vae_file == vae_file:
+        return
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.send_everything_to_cpu()
+    else:
+        sd_model.to(devices.cpu)
+
+    sd_hijack.model_hijack.undo_hijack(sd_model)
+
+    load_vae(sd_model, vae_file)
+
+    sd_hijack.model_hijack.hijack(sd_model)
+    script_callbacks.model_loaded_callback(sd_model)
+
+    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+        sd_model.to(devices.device)
+
+    print(f"VAE Weights loaded.")
+    return sd_model
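
A minimal usage sketch (not part of the diff): assuming this file lands as a module importable as modules.sd_vae and that a checkpoint is already loaded into shared.sd_model (both are assumptions, not shown above), the new entry point would be driven roughly like this:

    from modules import shared, sd_vae

    # "auto" lets resolve_vae() pick a VAE matching the current checkpoint;
    # a key from vae_dict (a VAE filename without extension) selects a specific file.
    sd_vae.reload_vae_weights(shared.sd_model, vae_file="auto")

Because reload_vae_weights() returns early when the resolved file equals loaded_vae_file, repeated calls with an unchanged setting are cheap.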