Skip to content

Commit f2a5cbe

Browse files
committed
fix #3986 breaking --no-half-vae
1 parent 675b51e commit f2a5cbe

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

modules/sd_models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,20 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
183183
model.to(memory_format=torch.channels_last)
184184

185185
if not shared.cmd_opts.no_half:
186+
vae = model.first_stage_model
187+
188+
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
189+
if shared.cmd_opts.no_half_vae:
190+
model.first_stage_model = None
191+
186192
model.half()
193+
model.first_stage_model = vae
187194

188195
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
189196
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
190197

198+
model.first_stage_model.to(devices.dtype_vae)
199+
191200
if shared.opts.sd_checkpoint_cache > 0:
192201
# if PR #4035 were to get merged, restore base VAE first before caching
193202
checkpoints_loaded[checkpoint_key] = model.state_dict().copy()

0 commit comments

Comments
 (0)