Skip to content

Commit 3d4f4b6

Browse files
committed
support external legacy config files with no personalization section
1 parent 249173f commit 3d4f4b6

File tree

3 files changed

+57
-16
lines changed

3 files changed

+57
-16
lines changed

ldm/invoke/model_manager.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from enum import Enum
2020
from pathlib import Path
2121
from shutil import move, rmtree
22-
from typing import Any, Callable, Optional, Union
22+
from typing import Any, Callable, Optional, Union, List
2323

2424
import safetensors
2525
import safetensors.torch
@@ -368,11 +368,19 @@ def _load_ckpt_model(self, model_name, mconfig):
368368
# check whether this is a v2 file and force conversion
369369
convert = Globals.ckpt_convert or self.is_v2_config(config)
370370

371+
if matching_config := self._scan_for_matching_file(Path(weights),suffixes=['.yaml']):
372+
print(f' | Using external config file {matching_config}')
373+
config = matching_config
374+
371375
# get the path to the custom vae, if any
372376
vae_path = None
377+
# first we use whatever is in the config file
373378
if vae:
374379
path = Path(vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae)))
375-
vae_path = path if path.exists() else None
380+
if path.exists():
381+
vae_path = path
382+
# then we look for a file with the same basename
383+
vae_path = vae_path or self._scan_for_matching_file(Path(weights))
376384

377385
# if converting automatically to diffusers, then we do the conversion and return
378386
# a diffusers pipeline
@@ -449,7 +457,7 @@ def _load_ckpt_model(self, model_name, mconfig):
449457

450458
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
451459
if vae_path:
452-
print(f" | Loading VAE weights from: {vae}")
460+
print(f" | Loading VAE weights from: {vae_path}")
453461
if vae_path.suffix in [".ckpt", ".pt"]:
454462
self.scan_model(vae_path.name, vae_path)
455463
vae_ckpt = torch.load(vae_path, map_location="cpu")
@@ -458,7 +466,7 @@ def _load_ckpt_model(self, model_name, mconfig):
458466
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
459467
model.first_stage_model.load_state_dict(vae_dict, strict=False)
460468
else:
461-
print(f" | VAE file {vae} not found. Skipping.")
469+
print(" | Using VAE built into model.")
462470

463471
model.to(self.device)
464472
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
@@ -915,12 +923,9 @@ def heuristic_import(
915923
convert = True
916924
print(" | This SD-v2 model will be converted to diffusers format for use")
917925

918-
# look for a custom vae
919-
vae_path = None
920-
for suffix in ["pt", "ckpt", "safetensors"]:
921-
if (model_path.with_suffix(f".vae.{suffix}")).exists():
922-
vae_path = model_path.with_suffix(f".vae.{suffix}")
923-
print(f" | Using VAE file {vae_path.name}")
926+
if (vae_path := self._scan_for_matching_file(model_path)):
927+
print(f" | Using VAE file {vae_path.name}")
928+
924929
if convert:
925930
diffuser_path = Path(
926931
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
@@ -1316,6 +1321,22 @@ def _cached_sha256(self, path, data) -> Union[str, bytes]:
13161321
f.write(hash)
13171322
return hash
13181323

1324+
@classmethod
1325+
def _scan_for_matching_file(
1326+
self,model_path: Path,
1327+
suffixes: List[str]=['.vae.pt','.vae.ckpt','.vae.safetensors']
1328+
)->Path:
1329+
"""
1330+
Find a file with same basename as the indicated model, but with one
1331+
of the suffixes passed.
1332+
"""
1333+
# look for a custom vae
1334+
vae_path = None
1335+
for suffix in suffixes:
1336+
if model_path.with_suffix(suffix).exists():
1337+
vae_path = model_path.with_suffix(suffix)
1338+
return vae_path
1339+
13191340
def _load_vae(self, vae_config) -> AutoencoderKL:
13201341
vae_args = {}
13211342
try:

ldm/models/diffusion/ddpm.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from tqdm import tqdm
2020
from torchvision.utils import make_grid
2121
from pytorch_lightning.utilities.distributed import rank_zero_only
22-
from omegaconf import ListConfig
22+
from omegaconf import ListConfig, OmegaConf
2323
import urllib
2424

2525
from ldm.modules.textual_inversion_manager import TextualInversionManager
@@ -617,7 +617,7 @@ def __init__(
617617
self,
618618
first_stage_config,
619619
cond_stage_config,
620-
personalization_config,
620+
personalization_config=None,
621621
num_timesteps_cond=None,
622622
cond_stage_key='image',
623623
cond_stage_trainable=False,
@@ -676,6 +676,8 @@ def __init__(
676676
for param in self.model.parameters():
677677
param.requires_grad = False
678678

679+
personalization_config = personalization_config or self._fallback_personalization_config()
680+
679681
self.embedding_manager = self.instantiate_embedding_manager(
680682
personalization_config, self.cond_stage_model
681683
)
@@ -800,6 +802,24 @@ def instantiate_embedding_manager(self, config, embedder):
800802

801803
return model
802804

805+
def _fallback_personalization_config(self)->dict:
806+
"""
807+
This protects us against custom legacy config files that
808+
don't contain the personalization_config section.
809+
"""
810+
return OmegaConf.create(
811+
dict(
812+
target='ldm.modules.embedding_manager.EmbeddingManager',
813+
params=dict(
814+
placeholder_strings=list('*'),
815+
initializer_words=list('sculpture'),
816+
per_image_tokens=False,
817+
num_vectors_per_token=1,
818+
progressive_words=False,
819+
)
820+
)
821+
)
822+
803823
def _get_denoise_row_from_list(
804824
self, samples, desc='', force_no_decoder_quantization=False
805825
):

ldm/modules/encoders/modules.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,10 @@ def forward(self, text, **kwargs):
463463
def encode(self, text, **kwargs):
464464
return self(text, **kwargs)
465465

466+
def set_textual_inversion_manager(self, manager): #TextualInversionManager):
467+
# TODO all of the weighting and expanding stuff needs be moved out of this class
468+
self.textual_inversion_manager = manager
469+
466470
@property
467471
def device(self):
468472
return self.transformer.device
@@ -476,10 +480,6 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
476480
fragment_weights_key = "fragment_weights"
477481
return_tokens_key = "return_tokens"
478482

479-
def set_textual_inversion_manager(self, manager): #TextualInversionManager):
480-
# TODO all of the weighting and expanding stuff needs be moved out of this class
481-
self.textual_inversion_manager = manager
482-
483483
def forward(self, text: list, **kwargs):
484484
# TODO all of the weighting and expanding stuff needs be moved out of this class
485485
'''

0 commit comments

Comments
 (0)