@@ -165,6 +165,9 @@ def load_model_weights(model, checkpoint_info):
165
165
checkpoint_file = checkpoint_info .filename
166
166
sd_model_hash = checkpoint_info .hash
167
167
168
+ if shared .opts .sd_checkpoint_cache > 0 and hasattr (model , "sd_checkpoint_info" ):
169
+ checkpoints_loaded [model .sd_checkpoint_info ] = model .state_dict ().copy ()
170
+
168
171
if checkpoint_info not in checkpoints_loaded :
169
172
print (f"Loading weights [{ sd_model_hash } ] from { checkpoint_file } " )
170
173
@@ -198,16 +201,14 @@ def load_model_weights(model, checkpoint_info):
198
201
model .first_stage_model .load_state_dict (vae_dict )
199
202
200
203
model .first_stage_model .to (devices .dtype_vae )
201
-
202
- if shared .opts .sd_checkpoint_cache > 0 :
203
- checkpoints_loaded [checkpoint_info ] = model .state_dict ().copy ()
204
- while len (checkpoints_loaded ) > shared .opts .sd_checkpoint_cache + 1 :
205
- checkpoints_loaded .popitem (last = False ) # LRU
206
204
else :
207
205
print (f"Loading weights [{ sd_model_hash } ] from cache" )
208
- checkpoints_loaded .move_to_end (checkpoint_info )
209
206
model .load_state_dict (checkpoints_loaded [checkpoint_info ])
210
207
208
+ if shared .opts .sd_checkpoint_cache > 0 :
209
+ while len (checkpoints_loaded ) > shared .opts .sd_checkpoint_cache :
210
+ checkpoints_loaded .popitem (last = False ) # LRU
211
+
211
212
model .sd_model_hash = sd_model_hash
212
213
model .sd_model_checkpoint = checkpoint_file
213
214
model .sd_checkpoint_info = checkpoint_info
0 commit comments