Skip to content

Commit 4145e27

Browse files
committed
move personalization fallback section into a static method
1 parent 3d4f4b6 commit 4145e27

File tree

1 file changed

+21
-21
lines changed

1 file changed

+21
-21
lines changed

ldm/models/diffusion/ddpm.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -609,15 +609,35 @@ def configure_optimizers(self):
609609
opt = torch.optim.AdamW(params, lr=lr)
610610
return opt
611611

612+
612613

613614
class LatentDiffusion(DDPM):
614615
"""main class"""
615616

617+
@staticmethod
618+
def _fallback_personalization_config()->dict:
619+
"""
620+
This protects us against custom legacy config files that
621+
don't contain the personalization_config section.
622+
"""
623+
return OmegaConf.create(
624+
dict(
625+
target='ldm.modules.embedding_manager.EmbeddingManager',
626+
params=dict(
627+
placeholder_strings=list('*'),
628+
initializer_words=list('sculpture'),
629+
per_image_tokens=False,
630+
num_vectors_per_token=1,
631+
progressive_words=False,
632+
)
633+
)
634+
)
635+
616636
def __init__(
617637
self,
618638
first_stage_config,
619639
cond_stage_config,
620-
personalization_config=None,
640+
personalization_config=_fallback_personalization_config(),
621641
num_timesteps_cond=None,
622642
cond_stage_key='image',
623643
cond_stage_trainable=False,
@@ -676,8 +696,6 @@ def __init__(
676696
for param in self.model.parameters():
677697
param.requires_grad = False
678698

679-
personalization_config = personalization_config or self._fallback_personalization_config()
680-
681699
self.embedding_manager = self.instantiate_embedding_manager(
682700
personalization_config, self.cond_stage_model
683701
)
@@ -802,24 +820,6 @@ def instantiate_embedding_manager(self, config, embedder):
802820

803821
return model
804822

805-
def _fallback_personalization_config(self)->dict:
806-
"""
807-
This protects us against custom legacy config files that
808-
don't contain the personalization_config section.
809-
"""
810-
return OmegaConf.create(
811-
dict(
812-
target='ldm.modules.embedding_manager.EmbeddingManager',
813-
params=dict(
814-
placeholder_strings=list('*'),
815-
initializer_words=list('sculpture'),
816-
per_image_tokens=False,
817-
num_vectors_per_token=1,
818-
progressive_words=False,
819-
)
820-
)
821-
)
822-
823823
def _get_denoise_row_from_list(
824824
self, samples, desc='', force_no_decoder_quantization=False
825825
):

0 commit comments

Comments
 (0)