@@ -163,13 +163,21 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
163
163
checkpoint_file = checkpoint_info .filename
164
164
sd_model_hash = checkpoint_info .hash
165
165
166
- if shared .opts .sd_checkpoint_cache > 0 and hasattr (model , "sd_checkpoint_info" ):
166
+ cache_enabled = shared .opts .sd_checkpoint_cache > 0
167
+
168
+ if cache_enabled :
167
169
sd_vae .restore_base_vae (model )
168
- checkpoints_loaded [model .sd_checkpoint_info ] = model .state_dict ().copy ()
169
170
170
171
vae_file = sd_vae .resolve_vae (checkpoint_file , vae_file = vae_file )
171
172
172
- if checkpoint_info not in checkpoints_loaded :
173
+ if cache_enabled and checkpoint_info in checkpoints_loaded :
174
+ # use checkpoint cache
175
+ vae_name = sd_vae .get_filename (vae_file ) if vae_file else None
176
+ vae_message = f" with { vae_name } VAE" if vae_name else ""
177
+ print (f"Loading weights [{ sd_model_hash } ]{ vae_message } from cache" )
178
+ model .load_state_dict (checkpoints_loaded [checkpoint_info ])
179
+ else :
180
+ # load from file
173
181
print (f"Loading weights [{ sd_model_hash } ] from { checkpoint_file } " )
174
182
175
183
pl_sd = torch .load (checkpoint_file , map_location = shared .weight_load_location )
@@ -180,6 +188,10 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
180
188
del pl_sd
181
189
model .load_state_dict (sd , strict = False )
182
190
del sd
191
+
192
+ if cache_enabled :
193
+ # cache newly loaded model
194
+ checkpoints_loaded [checkpoint_info ] = model .state_dict ().copy ()
183
195
184
196
if shared .cmd_opts .opt_channelslast :
185
197
model .to (memory_format = torch .channels_last )
@@ -199,13 +211,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
199
211
200
212
model .first_stage_model .to (devices .dtype_vae )
201
213
202
- else :
203
- vae_name = sd_vae .get_filename (vae_file ) if vae_file else None
204
- vae_message = f" with { vae_name } VAE" if vae_name else ""
205
- print (f"Loading weights [{ sd_model_hash } ]{ vae_message } from cache" )
206
- model .load_state_dict (checkpoints_loaded [checkpoint_info ])
207
-
208
- if shared .opts .sd_checkpoint_cache > 0 :
214
+ # clean up cache if limit is reached
215
+ if cache_enabled :
209
216
while len (checkpoints_loaded ) > shared .opts .sd_checkpoint_cache :
210
217
checkpoints_loaded .popitem (last = False ) # LRU
211
218
0 commit comments