@@ -46,12 +46,7 @@ class SDLegacyType(Enum):
4646 V2_v = 5
4747 UNKNOWN = 99
4848
49-
5049DEFAULT_MAX_MODELS = 2
51- VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
52- "vae-ft-mse-840000-ema-pruned" : "stabilityai/sd-vae-ft-mse" ,
53- }
54-
5550
5651class ModelManager (object ):
5752 def __init__ (
@@ -382,6 +377,12 @@ def _load_ckpt_model(self, model_name, mconfig):
382377 # check whether this is a v2 file and force conversion
383378 convert = Globals .ckpt_convert or self .is_v2_config (config )
384379
380+ # get the path to the custom vae, if any
381+ vae_path = None
382+ if vae :
383+ path = Path (vae if os .path .isabs (vae ) else os .path .normpath (os .path .join (Globals .root , vae )))
384+ vae_path = path if path .exists () else None
385+
385386 # if converting automatically to diffusers, then we do the conversion and return
386387 # a diffusers pipeline
387388 if convert :
@@ -390,15 +391,18 @@ def _load_ckpt_model(self, model_name, mconfig):
390391 )
391392 from ldm .invoke .ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt
392393
393- self .offload_model (self .current_model )
394- if vae_config := self ._choose_diffusers_vae (model_name ):
395- vae = self ._load_vae (vae_config )
394+ try :
395+ if self .list_models ()[self .current_model ]['status' ] == 'active' :
396+ self .offload_model (self .current_model )
397+ except Exception :
398+ pass
399+
396400 if self ._has_cuda ():
397401 torch .cuda .empty_cache ()
398402 pipeline = load_pipeline_from_original_stable_diffusion_ckpt (
399403 checkpoint_path = weights ,
400404 original_config_file = config ,
401- vae = vae ,
405+ vae_path = vae_path ,
402406 return_generator_pipeline = True ,
403407 precision = torch .float16
404408 if self .precision == "float16"
@@ -453,20 +457,17 @@ def _load_ckpt_model(self, model_name, mconfig):
453457 print (" | Using more accurate float32 precision" )
454458
455459 # look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
456- if vae :
457- if not os .path .isabs (vae ):
458- vae = os .path .normpath (os .path .join (Globals .root , vae ))
459- if os .path .exists (vae ):
460- print (f" | Loading VAE weights from: { vae } " )
461- if vae .endswith ((".ckpt" , ".pt" )):
462- self .scan_model (vae , vae )
463- vae_ckpt = torch .load (vae , map_location = "cpu" )
464- else :
465- vae_ckpt = safetensors .torch .load_file (vae )
466- vae_dict = {k : v for k , v in vae_ckpt .items () if k [0 :4 ] != "loss" }
467- model .first_stage_model .load_state_dict (vae_dict , strict = False )
460+ if vae_path :
461+ print (f" | Loading VAE weights from: { vae } " )
462+ if vae_path .suffix in [".ckpt" , ".pt" ]:
463+ self .scan_model (vae_path .name , vae_path )
464+ vae_ckpt = torch .load (vae_path , map_location = "cpu" )
468465 else :
469- print (f" | VAE file { vae } not found. Skipping." )
466+ vae_ckpt = safetensors .torch .load_file (vae_path )
467+ vae_dict = {k : v for k , v in vae_ckpt ["state_dict" ].items () if k [0 :4 ] != "loss" }
468+ model .first_stage_model .load_state_dict (vae_dict , strict = False )
469+ else :
470+ print (f" | VAE file { vae } not found. Skipping." )
470471
471472 model .to (self .device )
472473 # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
@@ -820,7 +821,6 @@ def heuristic_import(
820821 print (f" | { thing } appears to be a diffusers file on disk" )
821822 model_name = self .import_diffuser_model (
822823 thing ,
823- vae = dict (repo_id = "stabilityai/sd-vae-ft-mse" ),
824824 model_name = model_name ,
825825 description = description ,
826826 commit_to_conf = commit_to_conf ,
@@ -930,12 +930,11 @@ def heuristic_import(
930930 if (model_path .with_suffix (f".vae.{ suffix } " )).exists ():
931931 vae_path = model_path .with_suffix (f".vae.{ suffix } " )
932932 print (f" | Using VAE file { vae_path .name } " )
933- vae = None if vae_path else dict (repo_id = "stabilityai/sd-vae-ft-mse" )
934-
935933 if convert :
936934 diffuser_path = Path (
937935 Globals .root , "models" , Globals .converted_ckpts_dir , model_path .stem
938936 )
937+ vae = None if vae_path else dict (repo_id = "stabilityai/sd-vae-ft-mse" )
939938 model_name = self .convert_and_import (
940939 model_path ,
941940 diffusers_path = diffuser_path ,
@@ -1008,14 +1007,17 @@ def convert_and_import(
10081007 try :
10091008 # By passing the specified VAE to the conversion function, the autoencoder
10101009 # will be built into the model rather than tacked on afterward via the config file
1011- vae_model = self ._load_vae (vae ) if vae else None
1010+ vae_model = None
1011+ if vae :
1012+ vae_model = self ._load_vae (vae )
1013+ vae_path = None
10121014 convert_ckpt_to_diffusers (
10131015 ckpt_path ,
10141016 diffusers_path ,
10151017 extract_ema = True ,
10161018 original_config_file = original_config_file ,
10171019 vae = vae_model ,
1018- vae_path = str ( vae_path ) if vae_path else None ,
1020+ vae_path = vae_path ,
10191021 scan_needed = scan_needed ,
10201022 )
10211023 print (
@@ -1062,36 +1064,6 @@ def search_models(self, search_folder):
10621064
10631065 return search_folder , found_models
10641066
1065- def _choose_diffusers_vae (
1066- self , model_name : str , vae : str = None
1067- ) -> Union [dict , str ]:
1068- # In the event that the original entry is using a custom ckpt VAE, we try to
1069- # map that VAE onto a diffuser VAE using a hard-coded dictionary.
1070- # I would prefer to do this differently: We load the ckpt model into memory, swap the
1071- # VAE in memory, and then pass that to convert_ckpt_to_diffusers() so that the swapped
1072- # VAE is built into the model. However, when I tried this I got obscure key errors.
1073- if vae :
1074- return vae
1075- if model_name in self .config and (
1076- vae_ckpt_path := self .model_info (model_name ).get ("vae" , None )
1077- ):
1078- vae_basename = Path (vae_ckpt_path ).stem
1079- diffusers_vae = None
1080- if diffusers_vae := VAE_TO_REPO_ID .get (vae_basename , None ):
1081- print (
1082- f">> { vae_basename } VAE corresponds to known { diffusers_vae } diffusers version"
1083- )
1084- vae = {"repo_id" : diffusers_vae }
1085- else :
1086- print (
1087- f'** Custom VAE "{ vae_basename } " found, but corresponding diffusers model unknown'
1088- )
1089- print (
1090- '** Using "stabilityai/sd-vae-ft-mse"; If this isn\' t right, please edit the model config'
1091- )
1092- vae = {"repo_id" : "stabilityai/sd-vae-ft-mse" }
1093- return vae
1094-
10951067 def _make_cache_room (self ) -> None :
10961068 num_loaded_models = len (self .models )
10971069 if num_loaded_models >= self .max_loaded_models :
0 commit comments