1919from enum import Enum
2020from pathlib import Path
2121from shutil import move , rmtree
22- from typing import Any , Callable , Optional , Union
22+ from typing import Any , Callable , Optional , Union , List
2323
2424import safetensors
2525import safetensors .torch
@@ -368,11 +368,19 @@ def _load_ckpt_model(self, model_name, mconfig):
368368 # check whether this is a v2 file and force conversion
369369 convert = Globals .ckpt_convert or self .is_v2_config (config )
370370
371+ if matching_config := self ._scan_for_matching_file (Path (weights ),suffixes = ['.yaml' ]):
372+ print (f' | Using external config file { matching_config } ' )
373+ config = matching_config
374+
371375 # get the path to the custom vae, if any
372376 vae_path = None
377+ # first we use whatever is in the config file
373378 if vae :
374379 path = Path (vae if os .path .isabs (vae ) else os .path .normpath (os .path .join (Globals .root , vae )))
375- vae_path = path if path .exists () else None
380+ if path .exists ():
381+ vae_path = path
382+ # then we look for a file with the same basename
383+ vae_path = vae_path or self ._scan_for_matching_file (Path (weights ))
376384
377385 # if converting automatically to diffusers, then we do the conversion and return
378386 # a diffusers pipeline
@@ -449,7 +457,7 @@ def _load_ckpt_model(self, model_name, mconfig):
449457
450458 # look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
451459 if vae_path :
452- print (f" | Loading VAE weights from: { vae } " )
460+ print (f" | Loading VAE weights from: { vae_path } " )
453461 if vae_path .suffix in [".ckpt" , ".pt" ]:
454462 self .scan_model (vae_path .name , vae_path )
455463 vae_ckpt = torch .load (vae_path , map_location = "cpu" )
@@ -458,7 +466,7 @@ def _load_ckpt_model(self, model_name, mconfig):
458466 vae_dict = {k : v for k , v in vae_ckpt ["state_dict" ].items () if k [0 :4 ] != "loss" }
459467 model .first_stage_model .load_state_dict (vae_dict , strict = False )
460468 else :
461- print (f " | VAE file { vae } not found. Skipping ." )
469+ print (" | Using VAE built into model ." )
462470
463471 model .to (self .device )
464472 # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
@@ -915,12 +923,9 @@ def heuristic_import(
915923 convert = True
916924 print (" | This SD-v2 model will be converted to diffusers format for use" )
917925
918- # look for a custom vae
919- vae_path = None
920- for suffix in ["pt" , "ckpt" , "safetensors" ]:
921- if (model_path .with_suffix (f".vae.{ suffix } " )).exists ():
922- vae_path = model_path .with_suffix (f".vae.{ suffix } " )
923- print (f" | Using VAE file { vae_path .name } " )
926+ if (vae_path := self ._scan_for_matching_file (model_path )):
927+ print (f" | Using VAE file { vae_path .name } " )
928+
924929 if convert :
925930 diffuser_path = Path (
926931 Globals .root , "models" , Globals .converted_ckpts_dir , model_path .stem
@@ -1316,6 +1321,22 @@ def _cached_sha256(self, path, data) -> Union[str, bytes]:
13161321 f .write (hash )
13171322 return hash
13181323
1324+ @classmethod
1325+ def _scan_for_matching_file (
1326+ self ,model_path : Path ,
1327+ suffixes : List [str ]= ['.vae.pt' ,'.vae.ckpt' ,'.vae.safetensors' ]
1328+ )-> Path :
1329+ """
1330+ Find a file with same basename as the indicated model, but with one
1331+ of the suffixes passed.
1332+ """
1333+ # look for a custom vae
1334+ vae_path = None
1335+ for suffix in suffixes :
1336+ if model_path .with_suffix (suffix ).exists ():
1337+ vae_path = model_path .with_suffix (suffix )
1338+ return vae_path
1339+
13191340 def _load_vae (self , vae_config ) -> AutoencoderKL :
13201341 vae_args = {}
13211342 try :
0 commit comments