
Commit b50ff4f

Reduce peak memory usage when changing models
A few tweaks to reduce peak memory usage, the biggest being that if we aren't using the checkpoint cache, we shouldn't duplicate the model state dict just to immediately throw it away. On my machine with 16GB of RAM, this change means I can typically change models, whereas before it would typically OOM.
1 parent 737eb28 commit b50ff4f

File tree

1 file changed (+7, -4)

modules/sd_models.py

Lines changed: 7 additions & 4 deletions
@@ -170,7 +170,9 @@ def load_model_weights(model, checkpoint_info):
             print(f"Global Step: {pl_sd['global_step']}")

         sd = get_state_dict_from_checkpoint(pl_sd)
-        missing, extra = model.load_state_dict(sd, strict=False)
+        del pl_sd
+        model.load_state_dict(sd, strict=False)
+        del sd

         if shared.cmd_opts.opt_channelslast:
             model.to(memory_format=torch.channels_last)
@@ -194,9 +196,10 @@ def load_model_weights(model, checkpoint_info):

         model.first_stage_model.to(devices.dtype_vae)

-        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
-        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
-            checkpoints_loaded.popitem(last=False)  # LRU
+        if shared.opts.sd_checkpoint_cache > 0:
+            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+            while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+                checkpoints_loaded.popitem(last=False)  # LRU
     else:
         print(f"Loading weights [{sd_model_hash}] from cache")
         checkpoints_loaded.move_to_end(checkpoint_info)
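
For context, here is a minimal standalone sketch of the pattern this commit applies, assuming a torch-serialized checkpoint whose weights sit under a "state_dict" key and an OrderedDict used as an LRU cache. The names load_weights and cache_size are illustrative, not the repository's API: in the real module, get_state_dict_from_checkpoint does the unwrapping and shared.opts.sd_checkpoint_cache plays the role of cache_size. The idea is to drop the wrapper dict as soon as the state dict is extracted, drop the state dict once the model has loaded it, and only copy the model's weights into the cache when caching is actually enabled.

# Minimal sketch of the memory-saving pattern (illustrative names,
# not the repository's exact code).
from collections import OrderedDict

import torch

checkpoints_loaded = OrderedDict()  # LRU cache of checkpoint state dicts


def load_weights(model, checkpoint_file, cache_size=0):
    pl_sd = torch.load(checkpoint_file, map_location="cpu")
    sd = pl_sd.get("state_dict", pl_sd)  # unwrap a Lightning-style checkpoint
    del pl_sd                            # drop the reference to the wrapper dict
    model.load_state_dict(sd, strict=False)
    del sd                               # the model now holds its own copy

    if cache_size > 0:                   # only duplicate weights when caching
        checkpoints_loaded[checkpoint_file] = model.state_dict().copy()
        while len(checkpoints_loaded) > cache_size:
            checkpoints_loaded.popitem(last=False)  # evict oldest entry (LRU)

Before this change, the state-dict copy was made unconditionally, so a full second copy of the weights existed in RAM even when the cache size was zero; gating it behind the cache option and deleting intermediate references lowers the peak.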
