Skip to content

Commit 23d9361

Browse files
committed
autoconvert ckpt VAEs assigned to diffusers models
1 parent ce22a15 commit 23d9361

File tree

2 files changed

+19
-11
lines changed

2 files changed

+19
-11
lines changed

ldm/invoke/ckpt_to_diffuser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -620,10 +620,10 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
620620
for key in keys:
621621
if key.startswith(vae_key):
622622
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
623-
new_checkpoint = _convert_ldm_vae_checkpoint(vae_state_dict,config)
623+
new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict,config)
624624
return new_checkpoint
625625

626-
def _convert_ldm_vae_checkpoint(vae_state_dict, config):
626+
def convert_ldm_vae_state_dict(vae_state_dict, config):
627627
new_checkpoint = {}
628628

629629
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]

ldm/invoke/model_manager.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,6 @@
3030
from omegaconf import OmegaConf
3131
from omegaconf.dictconfig import DictConfig
3232
from picklescan.scanner import scan_file_path
33-
34-
from .ckpt_to_diffuser import (
35-
load_pipeline_from_original_stable_diffusion_ckpt,
36-
create_vae_diffusers_config,
37-
convert_ldm_vae_checkpoint,
38-
)
3933
from ldm.invoke.devices import CPU_DEVICE
4034
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
4135
from ldm.invoke.globals import Globals, global_cache_dir
@@ -374,7 +368,10 @@ def _load_ckpt_model(self, model_name, mconfig):
374368
print(
375369
f">> Converting legacy checkpoint {model_name} into a diffusers model..."
376370
)
377-
371+
from .ckpt_to_diffuser import (
372+
load_pipeline_from_original_stable_diffusion_ckpt,
373+
)
374+
378375
if self._has_cuda():
379376
torch.cuda.empty_cache()
380377
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
@@ -1287,17 +1284,28 @@ def _load_vae(self, vae_config) -> AutoencoderKL:
12871284

12881285
return vae
12891286

1287+
@staticmethod
12901288
def convert_vae(vae_path: Union[Path,str])->AutoencoderKL:
1289+
print(f" | A checkpoint VAE was detected. Converting to diffusers format.")
1290+
vae_path = Path(Globals.root,vae_path).resolve()
1291+
1292+
from .ckpt_to_diffuser import (
1293+
create_vae_diffusers_config,
1294+
convert_ldm_vae_state_dict,
1295+
)
1296+
12911297
vae_path = Path(vae_path)
12921298
if vae_path.suffix in ['.pt','.ckpt']:
1293-
vae_state_dict = torch.load(vae_path)
1299+
vae_state_dict = torch.load(vae_path, map_location="cpu")
12941300
else:
12951301
vae_state_dict = safetensors.torch.load_file(vae_path)
1302+
if 'state_dict' in vae_state_dict:
1303+
vae_state_dict = vae_state_dict['state_dict']
12961304
# TODO: see if this works with 1.x inpaint models and 2.x models
12971305
config_file_path = Path(Globals.root,"configs/stable-diffusion/v1-inference.yaml")
12981306
original_conf = OmegaConf.load(config_file_path)
12991307
vae_config = create_vae_diffusers_config(original_conf, image_size=512) # TODO: fix
1300-
diffusers_vae = convert_ldm_vae_checkpoint(vae_state_dict,vae_config)
1308+
diffusers_vae = convert_ldm_vae_state_dict(vae_state_dict,vae_config)
13011309
vae = AutoencoderKL(**vae_config)
13021310
vae.load_state_dict(diffusers_vae)
13031311
return vae

0 commit comments

Comments (0)