
Commit 794ef86 (1 parent: a1ed225)

fix incorrect loading of external VAEs

- Closes #3073

3 files changed: 34 additions, 66 deletions


ldm/invoke/CLI.py (2 additions, 6 deletions)

@@ -776,14 +776,10 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
             original_config_file = Path(model_info["config"])
             model_name = model_name_or_path
             model_description = model_info["description"]
-            vae = model_info.get("vae")
+            vae_path = model_info.get("vae")
         else:
             print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
             return
-        if vae and (vae_repo := ldm.invoke.model_manager.VAE_TO_REPO_ID.get(Path(vae).stem)):
-            vae_repo = dict(repo_id=vae_repo)
-        else:
-            vae_repo = None
         model_name = manager.convert_and_import(
             ckpt_path,
             diffusers_path=Path(
@@ -792,7 +788,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
             model_name=model_name,
             model_description=model_description,
             original_config_file=original_config_file,
-            vae=vae_repo,
+            vae_path=vae_path,
         )
     else:
         try:
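
Net effect in the CLI: the hard-coded VAE_TO_REPO_ID lookup is gone, and whatever "vae" entry the model stanza records is forwarded verbatim to convert_and_import() as vae_path. A minimal sketch of the new data flow; the model_info contents here are hypothetical:

# Hypothetical models.yaml stanza for a legacy checkpoint with a custom VAE.
model_info = {
    "weights": "models/ldm/stable-diffusion-v1/model.ckpt",
    "config": "configs/stable-diffusion/v1-inference.yaml",
    "description": "example model",
    "vae": "models/ldm/stable-diffusion-v1/my-custom-vae.ckpt",
}

# Before: only VAE filenames whose stem appeared in VAE_TO_REPO_ID survived
# conversion (everything else collapsed to vae_repo = None and was dropped).
# After: the raw path is passed through, so any on-disk VAE can be baked in.
vae_path = model_info.get("vae")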

ldm/invoke/ckpt_to_diffuser.py (3 additions, 3 deletions)

@@ -1037,10 +1037,10 @@ def convert_open_clip_checkpoint(checkpoint):
     return text_model
 
 def replace_checkpoint_vae(checkpoint, vae_path:str):
-    if vae_path.endswith(".safetensors"):
-        vae_ckpt = load_file(vae_path)
-    else:
+    if Path(vae_path).suffix in ['.pt','.ckpt']:
         vae_ckpt = torch.load(vae_path, map_location="cpu")
+    else:
+        vae_ckpt = load_file(vae_path)
     state_dict = vae_ckpt['state_dict'] if "state_dict" in vae_ckpt else vae_ckpt
     for vae_key in state_dict:
         new_key = f'first_stage_model.{vae_key}'
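
Note that the suffix test now goes through Path(vae_path).suffix, which works whether vae_path arrives as a str or a pathlib.Path; the old str.endswith() call assumed a plain string. Pickle formats (.pt, .ckpt) are matched explicitly and everything else falls through to safetensors. A self-contained sketch of the same dispatch; the helper name is mine:

from pathlib import Path
from typing import Union

import torch
from safetensors.torch import load_file

def load_vae_checkpoint(vae_path: Union[str, Path]) -> dict:
    # Pickle-based formats are matched explicitly by suffix; anything else
    # is assumed to be safetensors. Path() accepts both str and Path.
    if Path(vae_path).suffix in [".pt", ".ckpt"]:
        vae_ckpt = torch.load(vae_path, map_location="cpu")
    else:
        vae_ckpt = load_file(vae_path)
    # Training checkpoints may nest their tensors under a "state_dict" key;
    # flat safetensors files do not.
    return vae_ckpt["state_dict"] if "state_dict" in vae_ckpt else vae_ckpt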

ldm/invoke/model_manager.py (29 additions, 57 deletions)

@@ -46,12 +46,7 @@ class SDLegacyType(Enum):
     V2_v = 5
     UNKNOWN = 99
 
-
 DEFAULT_MAX_MODELS = 2
-VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
-    "vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse",
-}
-
 
 class ModelManager(object):
     def __init__(
@@ -382,6 +377,12 @@ def _load_ckpt_model(self, model_name, mconfig):
         # check whether this is a v2 file and force conversion
         convert = Globals.ckpt_convert or self.is_v2_config(config)
 
+        # get the path to the custom vae, if any
+        vae_path = None
+        if vae:
+            path = Path(vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae)))
+            vae_path = path if path.exists() else None
+
         # if converting automatically to diffusers, then we do the conversion and return
         # a diffusers pipeline
         if convert:
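
This added block is what makes relative "vae" entries in the model config usable: they are resolved against Globals.root, and a path that does not exist becomes None instead of propagating downstream. A standalone sketch of the lookup; the function name and signature are mine:

import os
from pathlib import Path
from typing import Optional

def resolve_vae_path(vae: Optional[str], root: str) -> Optional[Path]:
    # Relative entries are interpreted against the InvokeAI root directory;
    # missing files yield None rather than a bogus path.
    if not vae:
        return None
    path = Path(vae if os.path.isabs(vae) else os.path.normpath(os.path.join(root, vae)))
    return path if path.exists() else None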
@@ -390,15 +391,18 @@ def _load_ckpt_model(self, model_name, mconfig):
             )
             from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt
 
-            self.offload_model(self.current_model)
-            if vae_config := self._choose_diffusers_vae(model_name):
-                vae = self._load_vae(vae_config)
+            try:
+                if self.list_models()[self.current_model]['status'] == 'active':
+                    self.offload_model(self.current_model)
+            except Exception:
+                pass
+
             if self._has_cuda():
                 torch.cuda.empty_cache()
             pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
                 checkpoint_path=weights,
                 original_config_file=config,
-                vae=vae,
+                vae_path=vae_path,
                 return_generator_pipeline=True,
                 precision=torch.float16
                 if self.precision == "float16"
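
The unconditional offload_model() call is now guarded: on a first load there may be no active current model to offload, and any lookup failure is tolerated rather than aborting the conversion. The _choose_diffusers_vae() lookup disappears because vae_path was already resolved above. A sketch of the guard as a standalone helper; the manager argument is assumed to expose the same three members used here:

def offload_if_active(manager) -> None:
    # Offload the current model only when the manager reports it as active;
    # swallow lookup errors (e.g. nothing loaded yet, or a stale model name).
    try:
        if manager.list_models()[manager.current_model]["status"] == "active":
            manager.offload_model(manager.current_model)
    except Exception:
        pass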
@@ -453,20 +457,17 @@ def _load_ckpt_model(self, model_name, mconfig):
             print(" | Using more accurate float32 precision")
 
         # look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
-        if vae:
-            if not os.path.isabs(vae):
-                vae = os.path.normpath(os.path.join(Globals.root, vae))
-            if os.path.exists(vae):
-                print(f" | Loading VAE weights from: {vae}")
-                if vae.endswith((".ckpt", ".pt")):
-                    self.scan_model(vae, vae)
-                    vae_ckpt = torch.load(vae, map_location="cpu")
-                else:
-                    vae_ckpt = safetensors.torch.load_file(vae)
-                vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
-                model.first_stage_model.load_state_dict(vae_dict, strict=False)
+        if vae_path:
+            print(f" | Loading VAE weights from: {vae}")
+            if vae_path.suffix in [".ckpt", ".pt"]:
+                self.scan_model(vae_path.name, vae_path)
+                vae_ckpt = torch.load(vae_path, map_location="cpu")
             else:
-                print(f" | VAE file {vae} not found. Skipping.")
+                vae_ckpt = safetensors.torch.load_file(vae_path)
+            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+            model.first_stage_model.load_state_dict(vae_dict, strict=False)
+        else:
+            print(f" | VAE file {vae} not found. Skipping.")
 
         model.to(self.device)
         # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
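
On the non-converting path, the vae_path resolved earlier is reused, and only keys that do not start with "loss" are loaded into first_stage_model: VAE training checkpoints carry loss/discriminator tensors that the inference model has no parameters for, and strict=False tolerates any remainder. A small sketch of the filter, written defensively for flat checkpoints as well (the patched line above indexes ["state_dict"] directly):

def strip_loss_keys(vae_ckpt: dict) -> dict:
    # Keep only autoencoder weights; keys beginning with "loss" belong to the
    # training-time loss module and have no counterpart in first_stage_model.
    state_dict = vae_ckpt["state_dict"] if "state_dict" in vae_ckpt else vae_ckpt
    return {k: v for k, v in state_dict.items() if k[0:4] != "loss"}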
@@ -820,7 +821,6 @@ def heuristic_import(
             print(f" | {thing} appears to be a diffusers file on disk")
             model_name = self.import_diffuser_model(
                 thing,
-                vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
                 model_name=model_name,
                 description=description,
                 commit_to_conf=commit_to_conf,
@@ -930,12 +930,11 @@ def heuristic_import(
        if (model_path.with_suffix(f".vae.{suffix}")).exists():
            vae_path = model_path.with_suffix(f".vae.{suffix}")
            print(f" | Using VAE file {vae_path.name}")
-    vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
-
    if convert:
        diffuser_path = Path(
            Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
        )
+        vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
        model_name = self.convert_and_import(
            model_path,
            diffusers_path=diffuser_path,
@@ -1008,14 +1007,17 @@ def convert_and_import(
        try:
            # By passing the specified VAE to the conversion function, the autoencoder
            # will be built into the model rather than tacked on afterward via the config file
-            vae_model = self._load_vae(vae) if vae else None
+            vae_model=None
+            if vae:
+                vae_model=self._load_vae(vae)
+                vae_path=None
            convert_ckpt_to_diffusers(
                ckpt_path,
                diffusers_path,
                extract_ema=True,
                original_config_file=original_config_file,
                vae=vae_model,
-                vae_path=str(vae_path) if vae_path else None,
+                vae_path=vae_path,
                scan_needed=scan_needed,
            )
            print(
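
convert_and_import() now guarantees the converter sees exactly one VAE source: when a diffusers VAE spec is supplied it is loaded into vae_model and vae_path is cleared, so the two arguments can never compete. A restatement as a pure function; the names are mine and load_vae stands in for self._load_vae():

from typing import Callable, Optional, Tuple

def choose_vae_source(
    vae: Optional[dict],
    vae_path: Optional[str],
    load_vae: Callable[[dict], object],
) -> Tuple[Optional[object], Optional[str]]:
    # A diffusers VAE spec (e.g. {"repo_id": "stabilityai/sd-vae-ft-mse"})
    # takes precedence; clearing vae_path ensures convert_ckpt_to_diffusers()
    # receives exactly one VAE source.
    vae_model = None
    if vae:
        vae_model = load_vae(vae)
        vae_path = None
    return vae_model, vae_path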
@@ -1062,36 +1064,6 @@ def search_models(self, search_folder):
 
        return search_folder, found_models
 
-    def _choose_diffusers_vae(
-        self, model_name: str, vae: str = None
-    ) -> Union[dict, str]:
-        # In the event that the original entry is using a custom ckpt VAE, we try to
-        # map that VAE onto a diffuser VAE using a hard-coded dictionary.
-        # I would prefer to do this differently: We load the ckpt model into memory, swap the
-        # VAE in memory, and then pass that to convert_ckpt_to_diffusers() so that the swapped
-        # VAE is built into the model. However, when I tried this I got obscure key errors.
-        if vae:
-            return vae
-        if model_name in self.config and (
-            vae_ckpt_path := self.model_info(model_name).get("vae", None)
-        ):
-            vae_basename = Path(vae_ckpt_path).stem
-            diffusers_vae = None
-            if diffusers_vae := VAE_TO_REPO_ID.get(vae_basename, None):
-                print(
-                    f">> {vae_basename} VAE corresponds to known {diffusers_vae} diffusers version"
-                )
-                vae = {"repo_id": diffusers_vae}
-            else:
-                print(
-                    f'** Custom VAE "{vae_basename}" found, but corresponding diffusers model unknown'
-                )
-                print(
-                    '** Using "stabilityai/sd-vae-ft-mse"; If this isn\'t right, please edit the model config'
-                )
-                vae = {"repo_id": "stabilityai/sd-vae-ft-mse"}
-        return vae
-
    def _make_cache_room(self) -> None:
        num_loaded_models = len(self.models)
        if num_loaded_models >= self.max_loaded_models:
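
With vae_path plumbed straight through to the converter, _choose_diffusers_vae() and its repo-id guesswork are no longer needed. The approach its own comment wished for, swapping the VAE into the checkpoint in memory, is what replace_checkpoint_vae() in ckpt_to_diffuser.py now does. A minimal sketch of that splice, consistent with the loop shown in the earlier hunk:

def splice_vae_into_checkpoint(checkpoint: dict, vae_state_dict: dict) -> None:
    # Each VAE tensor overwrites the checkpoint's autoencoder weight under the
    # "first_stage_model." prefix, so the chosen VAE is baked into the model
    # before conversion to diffusers format.
    for vae_key, tensor in vae_state_dict.items():
        checkpoint[f"first_stage_model.{vae_key}"] = tensor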
