@@ -160,11 +160,15 @@ def get_state_dict_from_checkpoint(pl_sd):
 
 vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
 
-def load_model_weights(model, checkpoint_info, force=False):
+def load_model_weights(model, checkpoint_info, vae_file="auto"):
     checkpoint_file = checkpoint_info.filename
     sd_model_hash = checkpoint_info.hash
 
-    if force or checkpoint_info not in checkpoints_loaded:
+    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+
+    checkpoint_key = (checkpoint_info, vae_file)
+
+    if checkpoint_key not in checkpoints_loaded:
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
         pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
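This hunk threads an explicit VAE through the loader and widens the cache key. A minimal sketch of the idea (hypothetical checkpoint and VAE names, plain strings standing in for real CheckpointInfo objects): keying the cache by a (checkpoint, vae) tuple gives the same checkpoint a separate slot per VAE, which is what makes the old force flag unnecessary.

    from collections import OrderedDict

    checkpoints_loaded = OrderedDict()

    # Hypothetical stand-ins; the real code keys by (CheckpointInfo, vae path).
    key_a = ("model-v1.ckpt", "vae-ft-mse.pt")
    key_b = ("model-v1.ckpt", "vae-anime.pt")

    checkpoints_loaded[key_a] = {"weights": "loaded with VAE A"}
    checkpoints_loaded[key_b] = {"weights": "loaded with VAE B"}

    # Same checkpoint, different VAE -> distinct cache entries.
    assert key_a != key_b and len(checkpoints_loaded) == 2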
@@ -185,24 +189,25 @@ def load_model_weights(model, checkpoint_info, force=False):
         devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
         devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
 
-        sd_vae.load_vae(model, checkpoint_file)
+        sd_vae.load_vae(model, vae_file)
         model.first_stage_model.to(devices.dtype_vae)
 
         if shared.opts.sd_checkpoint_cache > 0:
-            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+            checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
             while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
                 checkpoints_loaded.popitem(last=False)  # LRU
     else:
-        print(f"Loading weights [{sd_model_hash}] from cache")
-        checkpoints_loaded.move_to_end(checkpoint_info)
-        model.load_state_dict(checkpoints_loaded[checkpoint_info])
+        vae_name = sd_vae.get_filename(vae_file)
+        print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache")
+        checkpoints_loaded.move_to_end(checkpoint_key)
+        model.load_state_dict(checkpoints_loaded[checkpoint_key])
 
     model.sd_model_hash = sd_model_hash
     model.sd_model_checkpoint = checkpoint_file
     model.sd_checkpoint_info = checkpoint_info
 
 
-def load_model(checkpoint_info=None, force=False):
+def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
 
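The cache behaves as an LRU because checkpoints_loaded is an OrderedDict: a hit calls move_to_end to mark the entry most recently used, and inserts evict from the front with popitem(last=False). A self-contained sketch of that mechanic, with CACHE_SIZE standing in for shared.opts.sd_checkpoint_cache:

    from collections import OrderedDict

    CACHE_SIZE = 2  # stand-in for shared.opts.sd_checkpoint_cache
    cache = OrderedDict()

    def put(key, value):
        cache[key] = value
        while len(cache) > CACHE_SIZE:
            cache.popitem(last=False)  # evict the least recently used entry

    def get(key):
        cache.move_to_end(key)  # mark as most recently used
        return cache[key]

    put("A", 1)
    put("B", 2)
    get("A")     # "A" becomes most recently used
    put("C", 3)  # evicts "B", the oldest entry
    assert list(cache) == ["A", "C"]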
@@ -223,7 +228,7 @@ def load_model(checkpoint_info=None, force=False):
 
     do_inpainting_hijack()
     sd_model = instantiate_from_config(sd_config.model)
-    load_model_weights(sd_model, checkpoint_info, force=force)
+    load_model_weights(sd_model, checkpoint_info)
 
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
@@ -250,7 +255,7 @@ def reload_model_weights(sd_model, info=None, force=False):
 
     if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
         checkpoints_loaded.clear()
-        load_model(checkpoint_info, force=force)
+        load_model(checkpoint_info)
         return shared.sd_model
 
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -260,7 +265,7 @@ def reload_model_weights(sd_model, info=None, force=False):
 
     sd_hijack.model_hijack.undo_hijack(sd_model)
 
-    load_model_weights(sd_model, checkpoint_info, force=force)
+    load_model_weights(sd_model, checkpoint_info)
 
     sd_hijack.model_hijack.hijack(sd_model)
     script_callbacks.model_loaded_callback(sd_model)
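The last three hunks all delete the same thing: the force=force plumbing through load_model, reload_model_weights, and load_model_weights. A hedged illustration of why it can go (load_or_reuse is a hypothetical helper, not from the repo): once the resolved VAE is part of the key, switching VAEs changes the key, so the lookup misses and the weights reload without any explicit invalidation.

    def load_or_reuse(cache, checkpoint, vae_file, loader):
        # A VAE switch changes the key, so the cache misses naturally.
        key = (checkpoint, vae_file)
        if key not in cache:
            cache[key] = loader(checkpoint, vae_file)
        return cache[key]

    cache = {}
    load_or_reuse(cache, "model-v1.ckpt", "vae-a.pt", lambda c, v: f"{c}+{v}")
    load_or_reuse(cache, "model-v1.ckpt", "vae-b.pt", lambda c, v: f"{c}+{v}")
    assert len(cache) == 2  # two entries, no force flag required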