@@ -614,30 +614,11 @@ def configure_optimizers(self):
614614class LatentDiffusion (DDPM ):
615615 """main class"""
616616
617- @staticmethod
618- def _fallback_personalization_config ()-> dict :
619- """
620- This protects us against custom legacy config files that
621- don't contain the personalization_config section.
622- """
623- return OmegaConf .create (
624- dict (
625- target = 'ldm.modules.embedding_manager.EmbeddingManager' ,
626- params = dict (
627- placeholder_strings = list ('*' ),
628- initializer_words = list ('sculpture' ),
629- per_image_tokens = False ,
630- num_vectors_per_token = 1 ,
631- progressive_words = False ,
632- )
633- )
634- )
635-
636617 def __init__ (
637618 self ,
638619 first_stage_config ,
639620 cond_stage_config ,
640- personalization_config = _fallback_personalization_config () ,
621+ personalization_config = None ,
641622 num_timesteps_cond = None ,
642623 cond_stage_key = 'image' ,
643624 cond_stage_trainable = False ,
@@ -695,7 +676,8 @@ def __init__(
695676 self .model .train = disabled_train
696677 for param in self .model .parameters ():
697678 param .requires_grad = False
698-
679+
680+ personalization_config = personalization_config or self ._fallback_personalization_config ()
699681 self .embedding_manager = self .instantiate_embedding_manager (
700682 personalization_config , self .cond_stage_model
701683 )
@@ -2170,6 +2152,25 @@ def on_save_checkpoint(self, checkpoint):
21702152
21712153 self .emb_ckpt_counter += 500
21722154
2155+ @classmethod
2156+ def _fallback_personalization_config (self )-> dict :
2157+ """
2158+ This protects us against custom legacy config files that
2159+ don't contain the personalization_config section.
2160+ """
2161+ return OmegaConf .create (
2162+ dict (
2163+ target = 'ldm.modules.embedding_manager.EmbeddingManager' ,
2164+ params = dict (
2165+ placeholder_strings = list ('*' ),
2166+ initializer_words = list ('sculpture' ),
2167+ per_image_tokens = False ,
2168+ num_vectors_per_token = 1 ,
2169+ progressive_words = False ,
2170+ )
2171+ )
2172+ )
2173+
21732174
21742175class DiffusionWrapper (pl .LightningModule ):
21752176 def __init__ (self , diff_model_config , conditioning_key ):
0 commit comments