Commit af758e9

Unload sd_model before loading the other

Parent: 5c9b362

File tree: 5 files changed, +34 −10 lines


modules/lowvram.py

Lines changed: 13 additions & 8 deletions
@@ -38,13 +38,18 @@ def send_me_to_gpu(module, _):
     # see below for register_forward_pre_hook;
     # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
     # useless here, and we just replace those methods
-    def first_stage_model_encode_wrap(self, encoder, x):
-        send_me_to_gpu(self, None)
-        return encoder(x)
 
-    def first_stage_model_decode_wrap(self, decoder, z):
-        send_me_to_gpu(self, None)
-        return decoder(z)
+    first_stage_model = sd_model.first_stage_model
+    first_stage_model_encode = sd_model.first_stage_model.encode
+    first_stage_model_decode = sd_model.first_stage_model.decode
+
+    def first_stage_model_encode_wrap(x):
+        send_me_to_gpu(first_stage_model, None)
+        return first_stage_model_encode(x)
+
+    def first_stage_model_decode_wrap(z):
+        send_me_to_gpu(first_stage_model, None)
+        return first_stage_model_decode(z)
 
     # remove three big modules, cond, first_stage, and unet from the model and then
     # send the model to GPU. Then put modules back. the modules will be in CPU.
@@ -56,8 +61,8 @@ def first_stage_model_decode_wrap(self, decoder, z):
     # register hooks for those the first two models
     sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
-    sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
-    sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
+    sd_model.first_stage_model.encode = first_stage_model_encode_wrap
+    sd_model.first_stage_model.decode = first_stage_model_decode_wrap
     parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
 
     if use_medvram:
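Why this change: the old lambdas captured the enclosing sd_model (via their default arguments and closure), so the installed encode/decode wrappers held a reference to the whole model object. The rewrite snapshots the original bound methods into locals once and closes over only those, presumably so the wrappers reference nothing beyond the first-stage module itself, which simplifies releasing the model on reload. A minimal, self-contained sketch of the pattern; FirstStage and install_wrappers are hypothetical names, not the repo's:

    class FirstStage:
        # stand-in for sd_model.first_stage_model
        def encode(self, x):
            return f"encode({x})"

    def install_wrappers(first_stage):
        original_encode = first_stage.encode    # bound method, snapshotted once

        def encode_wrap(x):
            # send_me_to_gpu(first_stage, None) would run here
            return original_encode(x)

        # the instance attribute shadows the class method; encode_wrap takes
        # no self because it is a plain function stored on the instance
        first_stage.encode = encode_wrap

    fs = FirstStage()
    install_wrappers(fs)
    print(fs.encode("x"))   # -> encode(x), routed through the wrapper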

modules/processing.py

Lines changed: 3 additions & 0 deletions
@@ -597,6 +597,9 @@ def infotext(iteration=0, position_in_batch=0):
     if p.scripts is not None:
         p.scripts.postprocess(p, res)
 
+    p.sd_model = None
+    p.sampler = None
+
     return res
 
 
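The processing object p may be kept alive well after the job finishes, for example while its result sits in the UI, and any sd_model/sampler attributes left on it would pin the checkpoint that produced it. Clearing them lets reference counting reclaim the model. A self-contained illustration; Model and Job are stand-in names:

    import gc
    import weakref

    class Model:
        pass

    class Job:
        def __init__(self, model):
            self.sd_model = model

    model = Model()
    probe = weakref.ref(model)
    job = Job(model)

    del model                  # our own name for the model is gone...
    gc.collect()
    print(probe() is None)     # False: job.sd_model still pins it

    job.sd_model = None        # what the patch does for p.sd_model / p.sampler
    gc.collect()
    print(probe() is None)     # True: the model can now be reclaimed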

modules/sd_hijack.py

Lines changed: 4 additions & 0 deletions
@@ -94,6 +94,10 @@ def undo_hijack(self, m):
         if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
             model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
 
+        self.layers = None
+        self.circular_enabled = False
+        self.clip = None
+
     def apply_circular(self, enable):
         if self.circular_enabled == enable:
             return
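model_hijack is a long-lived singleton, so whatever it cached while hijacking (the layers list, the clip reference) would keep the old checkpoint reachable even after its modules were unwrapped; undo_hijack therefore resets that state. A toy sketch of the hazard, where ModelHijack is a stand-in and not the real class:

    class ModelHijack:
        # stand-in for a module-level singleton that survives model swaps
        def __init__(self):
            self.layers = None
            self.clip = None
            self.circular_enabled = False

        def hijack(self, model):
            self.clip = model          # caches references into the model...
            self.layers = [model]      # ...which otherwise outlive the model

        def undo_hijack(self, model):
            # ...unwrap embeddings here...
            self.layers = None         # release the cached references
            self.circular_enabled = False
            self.clip = None

    model_hijack = ModelHijack()       # one instance for the process lifetime
    model_hijack.hijack(object())
    model_hijack.undo_hijack(object())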

modules/sd_models.py

Lines changed: 13 additions & 1 deletion
@@ -1,6 +1,7 @@
 import collections
 import os.path
 import sys
+import gc
 from collections import namedtuple
 import torch
 import re
@@ -220,6 +221,12 @@ def load_model(checkpoint_info=None):
     if checkpoint_info.config != shared.cmd_opts.config:
         print(f"Loading config from: {checkpoint_info.config}")
 
+    if shared.sd_model:
+        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+        shared.sd_model = None
+        gc.collect()
+        devices.torch_gc()
+
     sd_config = OmegaConf.load(checkpoint_info.config)
 
     if should_hijack_inpainting(checkpoint_info):
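This hunk is the core of the commit: before instantiating the new checkpoint, load_model un-hijacks the current model, drops the global reference, runs the Python garbage collector, and only then frees cached GPU memory, so the old and new models never coexist in RAM/VRAM. A condensed, runnable toy of the same ordering; Shared and the stubs are hypothetical, not the repo's code:

    import gc

    class Shared:
        sd_model = None

    def undo_hijack(model):
        print("unhooking", model)

    def torch_gc():
        # with torch installed this is roughly:
        #   torch.cuda.empty_cache(); torch.cuda.ipc_collect()
        print("freeing cached GPU memory")

    def load_model(name, shared):
        if shared.sd_model is not None:
            undo_hijack(shared.sd_model)   # restore patched modules first
            shared.sd_model = None         # drop the last strong reference
            gc.collect()                   # break any leftover cycles
            torch_gc()                     # hand freed VRAM back to the driver
        shared.sd_model = f"model({name})"

    shared = Shared()
    load_model("v1.ckpt", shared)
    load_model("v2.ckpt", shared)   # v1 is fully unloaded before v2 is built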
@@ -233,6 +240,7 @@ def load_model(checkpoint_info=None):
         checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
 
     do_inpainting_hijack()
+
     sd_model = instantiate_from_config(sd_config.model)
     load_model_weights(sd_model, checkpoint_info)
 
@@ -252,14 +260,18 @@
     return sd_model
 
 
-def reload_model_weights(sd_model, info=None):
+def reload_model_weights(sd_model=None, info=None):
     from modules import lowvram, devices, sd_hijack
     checkpoint_info = info or select_checkpoint()
 
+    if not sd_model:
+        sd_model = shared.sd_model
+
     if sd_model.sd_model_checkpoint == checkpoint_info.filename:
         return
 
     if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+        del sd_model
         checkpoints_loaded.clear()
         load_model(checkpoint_info)
         return shared.sd_model
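Two changes here: reload_model_weights now defaults to the live shared.sd_model, so callers no longer have to pass it, and it dels its local sd_model before delegating to load_model, because that local would otherwise keep the old checkpoint alive in this frame for the whole duration of the load. A runnable toy of that second point, where current stands in for the shared state:

    import gc
    import weakref

    class Model:
        pass

    current = {"sd_model": Model()}
    probe = weakref.ref(current["sd_model"])

    def load_model():
        current["sd_model"] = None
        gc.collect()
        print("old model gone before load:", probe() is None)
        current["sd_model"] = Model()

    def reload_model_weights(sd_model=None):
        if not sd_model:
            sd_model = current["sd_model"]
        # ...suppose the config changed, so a full reload is required...
        del sd_model      # without this, the local would keep the old model
                          # alive for the whole duration of load_model()
        load_model()

    reload_model_weights()    # prints: old model gone before load: True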

webui.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def initialize():
     modules.scripts.load_scripts()
 
     modules.sd_models.load_model()
-    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
+    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
     shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
     shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
 
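At the call site, the settings hook simply stops passing the model, letting reload_model_weights resolve and release the reference itself. The hook follows a per-key callback registry pattern: a function is registered against a setting name and fires when the value changes. A toy sketch of that pattern; this is hypothetical, not the webui's actual Options implementation:

    callbacks = {}
    opts = {}

    def onchange(key, func):
        callbacks[key] = func           # one callback per setting key

    def set_option(key, value):
        changed = opts.get(key) != value
        opts[key] = value
        if changed and key in callbacks:
            callbacks[key]()            # e.g. reload the model weights

    onchange("sd_model_checkpoint", lambda: print("reloading weights..."))
    set_option("sd_model_checkpoint", "v1-5-pruned.ckpt")   # fires the callback
    set_option("sd_model_checkpoint", "v1-5-pruned.ckpt")   # unchanged: no fire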
