|
9 | 9 | import contextlib |
10 | 10 | import gc |
11 | 11 | import hashlib |
12 | | -import io |
13 | 12 | import os |
14 | 13 | import re |
15 | 14 | import sys |
|
31 | 30 | from omegaconf import OmegaConf |
32 | 31 | from omegaconf.dictconfig import DictConfig |
33 | 32 | from picklescan.scanner import scan_file_path |
34 | | - |
35 | 33 | from ldm.invoke.devices import CPU_DEVICE |
36 | 34 | from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline |
37 | 35 | from ldm.invoke.globals import Globals, global_cache_dir |
38 | | -from ldm.util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name |
| 36 | +from ldm.util import ask_user, download_with_resume, url_attachment_name |
39 | 37 |
|
40 | 38 |
|
41 | 39 | class SDLegacyType(Enum): |
@@ -370,8 +368,9 @@ def _load_ckpt_model(self, model_name, mconfig): |
370 | 368 | print( |
371 | 369 | f">> Converting legacy checkpoint {model_name} into a diffusers model..." |
372 | 370 | ) |
373 | | - from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt |
374 | | - |
| 371 | + from .ckpt_to_diffuser import ( |
| 372 | + load_pipeline_from_original_stable_diffusion_ckpt, |
| 373 | + ) |
375 | 374 | if self._has_cuda(): |
376 | 375 | torch.cuda.empty_cache() |
377 | 376 | pipeline = load_pipeline_from_original_stable_diffusion_ckpt( |
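For context, the hunk above defers the ckpt_to_diffuser import to call time, so the heavy conversion machinery is only loaded when a legacy checkpoint actually needs converting. A minimal sketch of that pattern follows; the keyword argument name is an assumption for illustration, since the call's real arguments are not shown in this hunk:

    # Sketch only: importing at call time keeps module import cheap and
    # avoids circular imports. `checkpoint_path` is an assumed argument
    # name, not the function's verified signature.
    def convert_legacy_checkpoint(checkpoint_path: str):
        from ldm.invoke.ckpt_to_diffuser import (
            load_pipeline_from_original_stable_diffusion_ckpt,
        )
        return load_pipeline_from_original_stable_diffusion_ckpt(
            checkpoint_path=checkpoint_path,
        )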
@@ -1230,6 +1229,13 @@ def _scan_for_matching_file( |
1230 | 1229 | return vae_path |
1231 | 1230 |
|
1232 | 1231 | def _load_vae(self, vae_config) -> AutoencoderKL: |
| 1232 | + |
| 1233 | +        # Handle the common case of a user supplying a VAE .ckpt
| 1234 | +        # path in the vae field of a diffusers model: convert it to
| 1235 | +        # diffusers format and use the converted weights.
| 1236 | +        if isinstance(vae_config, (str, Path)):
| 1237 | +            return self.convert_vae(vae_config)
| 1238 | + |
1233 | 1239 | vae_args = {} |
1234 | 1240 | try: |
1235 | 1241 | name_or_path = self.model_name_or_path(vae_config) |
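To make the new dispatch concrete, here is a sketch of the two vae_config shapes _load_vae now accepts; the repo id and file path are illustrative values, not ones taken from this PR:

    # Dict-style config: handled by the pre-existing diffusers code path.
    vae_config = OmegaConf.create({"repo_id": "stabilityai/sd-vae-ft-mse"})
    # Bare checkpoint path (str or Path): routed to convert_vae() instead.
    vae_config = "models/vae/vae-ft-mse-840000-ema-pruned.ckpt"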
@@ -1277,6 +1283,31 @@ def _load_vae(self, vae_config) -> AutoencoderKL:
1277 | 1283 |
|
1278 | 1284 | return vae |
1279 | 1285 |
|
| 1286 | +    @staticmethod
| 1287 | +    def convert_vae(vae_path: Union[Path, str]) -> AutoencoderKL:
| 1288 | +        print(" | A checkpoint VAE was detected. Converting to diffusers format.")
| 1289 | +        vae_path = Path(Globals.root, vae_path).resolve()
| 1290 | +
| 1291 | +        from .ckpt_to_diffuser import (
| 1292 | +            create_vae_diffusers_config,
| 1293 | +            convert_ldm_vae_state_dict,
| 1294 | +        )
| 1295 | +
| 1296 | +        if vae_path.suffix in ['.pt', '.ckpt']:
| 1297 | +            vae_state_dict = torch.load(vae_path, map_location="cpu")
| 1298 | +        else:
| 1299 | +            vae_state_dict = safetensors.torch.load_file(vae_path)
| 1300 | +        if 'state_dict' in vae_state_dict:
| 1301 | +            vae_state_dict = vae_state_dict['state_dict']
| 1302 | +        # TODO: see if this works with 1.x inpaint models and 2.x models
| 1303 | +        config_file_path = Path(Globals.root, "configs/stable-diffusion/v1-inference.yaml")
| 1304 | +        original_conf = OmegaConf.load(config_file_path)
| 1305 | +        vae_config = create_vae_diffusers_config(original_conf, image_size=512)  # TODO: fix
| 1306 | +        diffusers_vae = convert_ldm_vae_state_dict(vae_state_dict, vae_config)
| 1307 | +        vae = AutoencoderKL(**vae_config)
| 1308 | +        vae.load_state_dict(diffusers_vae)
| 1309 | +        return vae
| 1310 | +
1280 | 1311 |     @staticmethod
1281 | 1312 |     def _delete_model_from_cache(repo_id):
1282 | 1313 |         cache_info = scan_cache_dir(global_cache_dir("diffusers"))
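A hypothetical usage sketch of the new static method; the module path, class name, and VAE filename below are assumptions for illustration:

    # Assumed module/class names; the VAE path resolves under Globals.root.
    from ldm.invoke.model_manager import ModelManager

    vae = ModelManager.convert_vae("models/vae/vae-ft-mse-840000-ema-pruned.ckpt")
    vae.to("cpu")  # an AutoencoderKL, ready to attach to a pipeline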
|