@@ -163,11 +163,11 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
163
163
checkpoint_file = checkpoint_info .filename
164
164
sd_model_hash = checkpoint_info .hash
165
165
166
- vae_file = sd_vae .resolve_vae (checkpoint_file , vae_file = vae_file )
166
+ if shared .opts .sd_checkpoint_cache > 0 and hasattr (model , "sd_checkpoint_info" ):
167
+ sd_vae .restore_base_vae (model )
168
+ checkpoints_loaded [model .sd_checkpoint_info ] = model .state_dict ().copy ()
167
169
168
- checkpoint_key = checkpoint_info
169
-
170
- if checkpoint_key not in checkpoints_loaded :
170
+ if checkpoint_info not in checkpoints_loaded :
171
171
print (f"Loading weights [{ sd_model_hash } ] from { checkpoint_file } " )
172
172
173
173
pl_sd = torch .load (checkpoint_file , map_location = shared .weight_load_location )
@@ -197,18 +197,15 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
197
197
198
198
model .first_stage_model .to (devices .dtype_vae )
199
199
200
- if shared .opts .sd_checkpoint_cache > 0 :
201
- # if PR #4035 were to get merged, restore base VAE first before caching
202
- checkpoints_loaded [checkpoint_key ] = model .state_dict ().copy ()
203
- while len (checkpoints_loaded ) > shared .opts .sd_checkpoint_cache :
204
- checkpoints_loaded .popitem (last = False ) # LRU
205
-
206
200
else :
207
201
vae_name = sd_vae .get_filename (vae_file ) if vae_file else None
208
202
vae_message = f" with { vae_name } VAE" if vae_name else ""
209
203
print (f"Loading weights [{ sd_model_hash } ]{ vae_message } from cache" )
210
- checkpoints_loaded .move_to_end (checkpoint_key )
211
- model .load_state_dict (checkpoints_loaded [checkpoint_key ])
204
+ model .load_state_dict (checkpoints_loaded [checkpoint_info ])
205
+
206
+ if shared .opts .sd_checkpoint_cache > 0 :
207
+ while len (checkpoints_loaded ) > shared .opts .sd_checkpoint_cache :
208
+ checkpoints_loaded .popitem (last = False ) # LRU
212
209
213
210
model .sd_model_hash = sd_model_hash
214
211
model .sd_model_checkpoint = checkpoint_file
0 commit comments