|
9 | 9 | import contextlib
10 | 10 | import gc
11 | 11 | import hashlib
12 | | -import io
13 | 12 | import os
14 | 13 | import re
15 | 14 | import sys
|
32 | 31 | from omegaconf.dictconfig import DictConfig
33 | 32 | from picklescan.scanner import scan_file_path
34 | 33 |
|
| 34 | +from .ckpt_to_diffuser import (
| 35 | +    convert_ldm_vae_checkpoint,
| 36 | +    create_vae_diffusers_config,
| 37 | +    load_pipeline_from_original_stable_diffusion_ckpt,
| 38 | +)
35 | 39 | from ldm.invoke.devices import CPU_DEVICE
36 | 40 | from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
37 | 41 | from ldm.invoke.globals import Globals, global_cache_dir
38 | | -from ldm.util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name
| 42 | +from ldm.util import ask_user, download_with_resume, url_attachment_name
39 | 43 |
|
40 | 44 |
|
41 | 45 | class SDLegacyType(Enum):
@@ -370,14 +374,7 @@ def _load_ckpt_model(self, model_name, mconfig):
370 | 374 |         print(
371 | 375 |             f">> Converting legacy checkpoint {model_name} into a diffusers model..."
372 | 376 |         )
373 | | -        from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt
374 | | - |
375 | | -# try: |
376 | | -# if self.list_models()[self.current_model]['status'] == 'active': |
377 | | -# self.offload_model(self.current_model) |
378 | | -# except Exception: |
379 | | -# pass |
380 | | - |
| 377 | + |
381 | 378 |         if self._has_cuda():
382 | 379 |             torch.cuda.empty_cache()
383 | 380 |         pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
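For reference, the conversion entry point used above can also be exercised on its own. A minimal sketch, assuming the checkpoint_path/original_config_file parameters of the upstream diffusers conversion script that ckpt_to_diffuser derives from; the paths are illustrative:

    # Hedged sketch: parameter names assumed from the upstream diffusers script.
    pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
        checkpoint_path="models/ldm/stable-diffusion-v1/model.ckpt",
        original_config_file="configs/stable-diffusion/v1-inference.yaml",
    )
    pipeline.save_pretrained("models/converted/my-model")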
@@ -1236,6 +1233,13 @@ def _scan_for_matching_file(
1236 | 1233 |         return vae_path
1237 | 1234 |
|
1238 | 1235 |     def _load_vae(self, vae_config) -> AutoencoderKL:
| 1236 | + |
| 1237 | +        # Handle the common case of a user pointing the vae field of a
| 1238 | +        # diffusers model at a legacy VAE checkpoint (.ckpt/.pt/.safetensors).
| 1239 | +        # Convert it to diffusers format on the fly and use the result.
| 1240 | +        if isinstance(vae_config, (str, Path)):
| 1241 | +            return self.convert_vae(vae_config)
| 1242 | + |
1239 | 1243 |         vae_args = {}
1240 | 1244 |         try:
1241 | 1245 |             name_or_path = self.model_name_or_path(vae_config)
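With the early return added above, _load_vae accepts either the usual models.yaml stanza or a bare checkpoint path. A sketch of the two call shapes, assuming manager is an instance of this class; the repo id and path are illustrative:

    # Existing behavior: a DictConfig stanza describing a diffusers VAE.
    vae = manager._load_vae(DictConfig({"repo_id": "stabilityai/sd-vae-ft-mse"}))

    # New behavior: a bare str/Path to a legacy VAE checkpoint is converted
    # to diffusers format on the fly via convert_vae().
    vae = manager._load_vae("models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt")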
@@ -1283,6 +1287,23 @@ def _load_vae(self, vae_config) -> AutoencoderKL:
1283 | 1287 |
|
1284 | 1288 |         return vae
1285 | 1289 |
|
| 1290 | +    def convert_vae(self, vae_path: Union[Path, str]) -> AutoencoderKL:
| 1291 | +        vae_path = Path(vae_path)
| 1292 | +        if vae_path.suffix in ['.pt', '.ckpt']:
| 1293 | +            vae_state_dict = torch.load(vae_path, map_location="cpu")
| 1294 | +            # standalone legacy VAEs usually nest their weights under a "state_dict" key
| 1295 | +            vae_state_dict = vae_state_dict.get("state_dict", vae_state_dict)
| 1296 | +        else:
| 1297 | +            vae_state_dict = safetensors.torch.load_file(vae_path)
| 1298 | +        # TODO: see if this works with 1.x inpaint models and 2.x models
| 1299 | +        config_file_path = Path(Globals.root, "configs/stable-diffusion/v1-inference.yaml")
| 1300 | +        original_conf = OmegaConf.load(config_file_path)
| 1301 | +        vae_config = create_vae_diffusers_config(original_conf, image_size=512)  # TODO: fix
| 1302 | +        diffusers_vae = convert_ldm_vae_checkpoint(vae_state_dict, vae_config)
| 1303 | +        vae = AutoencoderKL(**vae_config)
| 1304 | +        vae.load_state_dict(diffusers_vae)
| 1305 | +        return vae
| 1306 | +
1286 | 1307 |     @staticmethod
1287 | 1308 |     def _delete_model_from_cache(repo_id):
1288 | 1309 |         cache_info = scan_cache_dir(global_cache_dir("diffusers"))
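A usage sketch for the new method, assuming manager is an instance of this class and pipe is an already-loaded diffusers pipeline; the checkpoint path is illustrative:

    vae = manager.convert_vae("models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt")
    pipe.vae = vae  # swap the converted AutoencoderKL into the pipeline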