Skip to content

Commit 0ce628b

Browse files
autoconvert legacy VAEs (#3235)
This draft PR implements a system in which, if a diffusers model is loaded and the model manager detects that the user tried to assign a legacy checkpoint VAE to the model, the checkpoint will be converted to a diffusers VAE in RAM. It is a draft because it has not been carefully tested yet, and there are some edge cases that are not handled properly.
2 parents 53f5dfb + ddcf9a3 commit 0ce628b

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

ldm/invoke/ckpt_to_diffuser.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,10 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
620620
for key in keys:
621621
if key.startswith(vae_key):
622622
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
623+
new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict,config)
624+
return new_checkpoint
623625

626+
def convert_ldm_vae_state_dict(vae_state_dict, config):
624627
new_checkpoint = {}
625628

626629
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]

ldm/invoke/model_manager.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import contextlib
1010
import gc
1111
import hashlib
12-
import io
1312
import os
1413
import re
1514
import sys
@@ -31,11 +30,10 @@
3130
from omegaconf import OmegaConf
3231
from omegaconf.dictconfig import DictConfig
3332
from picklescan.scanner import scan_file_path
34-
3533
from ldm.invoke.devices import CPU_DEVICE
3634
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
3735
from ldm.invoke.globals import Globals, global_cache_dir
38-
from ldm.util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name
36+
from ldm.util import ask_user, download_with_resume, url_attachment_name
3937

4038

4139
class SDLegacyType(Enum):
@@ -370,8 +368,9 @@ def _load_ckpt_model(self, model_name, mconfig):
370368
print(
371369
f">> Converting legacy checkpoint {model_name} into a diffusers model..."
372370
)
373-
from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt
374-
371+
from .ckpt_to_diffuser import (
372+
load_pipeline_from_original_stable_diffusion_ckpt,
373+
)
375374
if self._has_cuda():
376375
torch.cuda.empty_cache()
377376
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
@@ -1230,6 +1229,13 @@ def _scan_for_matching_file(
12301229
return vae_path
12311230

12321231
def _load_vae(self, vae_config) -> AutoencoderKL:
1232+
1233+
# Handle the common case of a user shoving a VAE .ckpt into
1234+
# the vae field for a diffusers. We convert it into diffusers
1235+
# format and use it.
1236+
if type(vae_config) in [str,Path]:
1237+
return self.convert_vae(vae_config)
1238+
12331239
vae_args = {}
12341240
try:
12351241
name_or_path = self.model_name_or_path(vae_config)
@@ -1277,6 +1283,32 @@ def _load_vae(self, vae_config) -> AutoencoderKL:
12771283

12781284
return vae
12791285

1286+
@staticmethod
def convert_vae(vae_path: Union[Path, str]) -> AutoencoderKL:
    """Convert a legacy checkpoint VAE into a diffusers AutoencoderKL in RAM.

    Handles ``.pt``/``.ckpt`` (pickled torch checkpoints) and safetensors
    files. The converted model is returned directly; nothing is written
    to disk.

    :param vae_path: path to the checkpoint VAE, absolute or relative to
        ``Globals.root``.
    :return: an ``AutoencoderKL`` loaded with the converted weights.
    """
    print(" | A checkpoint VAE was detected. Converting to diffusers format.")
    # Interpret relative paths as relative to the InvokeAI root.
    vae_path = Path(Globals.root, vae_path).resolve()

    # Imported lazily to avoid a circular import at module load time.
    from .ckpt_to_diffuser import (
        create_vae_diffusers_config,
        convert_ldm_vae_state_dict,
    )

    if vae_path.suffix in ('.pt', '.ckpt'):
        # NOTE(review): torch.load on an untrusted checkpoint can execute
        # arbitrary pickled code; the file should be scanned before loading.
        vae_state_dict = torch.load(vae_path, map_location="cpu")
    else:
        vae_state_dict = safetensors.torch.load_file(vae_path)
    # Some checkpoints nest the weights under a 'state_dict' key.
    if 'state_dict' in vae_state_dict:
        vae_state_dict = vae_state_dict['state_dict']

    # TODO: see if this works with 1.x inpaint models and 2.x models
    config_file_path = Path(Globals.root, "configs/stable-diffusion/v1-inference.yaml")
    original_conf = OmegaConf.load(config_file_path)
    vae_config = create_vae_diffusers_config(original_conf, image_size=512)  # TODO: image size is hard-coded
    diffusers_state_dict = convert_ldm_vae_state_dict(vae_state_dict, vae_config)
    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(diffusers_state_dict)
    return vae
1311+
12801312
@staticmethod
12811313
def _delete_model_from_cache(repo_id):
12821314
cache_info = scan_cache_dir(global_cache_dir("diffusers"))

0 commit comments

Comments
 (0)