Skip to content

Commit ce22a15

Browse files
committed
convert VAEs to diffusers format automatically
- If the user enters a VAE .ckpt path into the VAE field of a diffusers model, the VAE will be automatically converted behind the scenes into a diffusers version, then loaded. - This commit is untested (done on an airplane).
1 parent 298ccda commit ce22a15

File tree

2 files changed

+32
-10
lines changed

2 files changed

+32
-10
lines changed

ldm/invoke/ckpt_to_diffuser.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,10 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
620620
for key in keys:
621621
if key.startswith(vae_key):
622622
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
623+
new_checkpoint = _convert_ldm_vae_checkpoint(vae_state_dict,config)
624+
return new_checkpoint
623625

626+
def _convert_ldm_vae_checkpoint(vae_state_dict, config):
624627
new_checkpoint = {}
625628

626629
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]

ldm/invoke/model_manager.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import contextlib
1010
import gc
1111
import hashlib
12-
import io
1312
import os
1413
import re
1514
import sys
@@ -32,10 +31,15 @@
3231
from omegaconf.dictconfig import DictConfig
3332
from picklescan.scanner import scan_file_path
3433

34+
from .ckpt_to_diffuser import (
35+
load_pipeline_from_original_stable_diffusion_ckpt,
36+
create_vae_diffusers_config,
37+
convert_ldm_vae_checkpoint,
38+
)
3539
from ldm.invoke.devices import CPU_DEVICE
3640
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
3741
from ldm.invoke.globals import Globals, global_cache_dir
38-
from ldm.util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name
42+
from ldm.util import ask_user, download_with_resume, url_attachment_name
3943

4044

4145
class SDLegacyType(Enum):
@@ -370,14 +374,7 @@ def _load_ckpt_model(self, model_name, mconfig):
370374
print(
371375
f">> Converting legacy checkpoint {model_name} into a diffusers model..."
372376
)
373-
from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt
374-
375-
# try:
376-
# if self.list_models()[self.current_model]['status'] == 'active':
377-
# self.offload_model(self.current_model)
378-
# except Exception:
379-
# pass
380-
377+
381378
if self._has_cuda():
382379
torch.cuda.empty_cache()
383380
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
@@ -1236,6 +1233,13 @@ def _scan_for_matching_file(
12361233
return vae_path
12371234

12381235
def _load_vae(self, vae_config) -> AutoencoderKL:
1236+
1237+
# Handle the common case of a user shoving a VAE .ckpt into
1238+
# the VAE field of a diffusers model. We convert it into diffusers
1239+
# format and use it.
1240+
if type(vae_config) in [str,Path]:
1241+
return self.convert_vae(vae_config)
1242+
12391243
vae_args = {}
12401244
try:
12411245
name_or_path = self.model_name_or_path(vae_config)
@@ -1283,6 +1287,21 @@ def _load_vae(self, vae_config) -> AutoencoderKL:
12831287

12841288
return vae
12851289

1290+
def convert_vae(self, vae_path: Union[Path, str]) -> AutoencoderKL:
    """Convert a legacy VAE checkpoint into a diffusers ``AutoencoderKL``.

    Loads a standalone VAE saved in the original LDM layout (``.pt``,
    ``.ckpt`` or safetensors), converts its weights to the diffusers
    naming scheme, and returns a ready-to-use ``AutoencoderKL``.

    :param vae_path: path to the legacy VAE checkpoint file.
    :return: an ``AutoencoderKL`` populated with the converted weights.

    Note: the original definition was missing ``self`` even though the
    caller invokes it as ``self.convert_vae(...)``; that is fixed here.
    """
    vae_path = Path(vae_path)
    if vae_path.suffix in ['.pt', '.ckpt']:
        # NOTE(review): torch.load unpickles arbitrary code; untrusted
        # checkpoints should be scanned (e.g. picklescan) before reaching
        # this point. map_location="cpu" keeps the load from failing on
        # CPU-only hosts when the checkpoint was saved from a GPU.
        vae_state_dict = torch.load(vae_path, map_location="cpu")
    else:
        vae_state_dict = safetensors.torch.load_file(vae_path)
    # Legacy checkpoints frequently wrap the weights in a "state_dict"
    # key; unwrap it when present so conversion sees the raw tensors.
    if isinstance(vae_state_dict, dict) and "state_dict" in vae_state_dict:
        vae_state_dict = vae_state_dict["state_dict"]
    # TODO: see if this works with 1.x inpaint models and 2.x models
    config_file_path = Path(Globals.root, "configs/stable-diffusion/v1-inference.yaml")
    original_conf = OmegaConf.load(config_file_path)
    vae_config = create_vae_diffusers_config(original_conf, image_size=512)  # TODO: fix hard-coded image size
    diffusers_vae = convert_ldm_vae_checkpoint(vae_state_dict, vae_config)
    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(diffusers_vae)
    return vae
1304+
12861305
@staticmethod
12871306
def _delete_model_from_cache(repo_id):
12881307
cache_info = scan_cache_dir(global_cache_dir("diffusers"))

0 commit comments

Comments
 (0)