@@ -609,15 +609,35 @@ def configure_optimizers(self):
         opt = torch.optim.AdamW(params, lr=lr)
         return opt
 
+
 
 class LatentDiffusion(DDPM):
     """main class"""
 
+    @staticmethod
+    def _fallback_personalization_config() -> dict:
+        """
+        This protects us against custom legacy config files that
+        don't contain the personalization_config section.
+        """
+        return OmegaConf.create(
+            dict(
+                target='ldm.modules.embedding_manager.EmbeddingManager',
+                params=dict(
+                    placeholder_strings=list('*'),
+                    initializer_words=list('sculpture'),
+                    per_image_tokens=False,
+                    num_vectors_per_token=1,
+                    progressive_words=False,
+                )
+            )
+        )
+
     def __init__(
         self,
         first_stage_config,
         cond_stage_config,
-        personalization_config=None,
+        personalization_config=_fallback_personalization_config(),
         num_timesteps_cond=None,
         cond_stage_key='image',
         cond_stage_trainable=False,
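A note on the new default (a minimal sketch, not part of this diff): the default value is computed once, while the class body executes, at which point `_fallback_personalization_config` is still the raw `staticmethod` object. Calling a `staticmethod` object directly inside the class body only works on Python 3.10 and later; on 3.9 and earlier it raises `TypeError: 'staticmethod' object is not callable`. The repro below uses hypothetical names (`Demo`, `_fallback`):

```python
# Minimal sketch with hypothetical names; mirrors the default-argument
# pattern introduced in the hunk above.
class Demo:
    @staticmethod
    def _fallback() -> dict:
        # Stand-in for _fallback_personalization_config()
        return {'target': 'ldm.modules.embedding_manager.EmbeddingManager'}

    # The default is evaluated exactly once, here, at class-definition time.
    # On Python <= 3.9 this line raises:
    #   TypeError: 'staticmethod' object is not callable
    def __init__(self, personalization_config=_fallback()):
        self.personalization_config = personalization_config

print(Demo().personalization_config['target'])
# -> ldm.modules.embedding_manager.EmbeddingManager
```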
@@ -676,8 +696,6 @@ def __init__(
         for param in self.model.parameters():
             param.requires_grad = False
 
-        personalization_config = personalization_config or self._fallback_personalization_config()
-
         self.embedding_manager = self.instantiate_embedding_manager(
             personalization_config, self.cond_stage_model
         )
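The removed `or` guard did slightly more than the new default argument does: a default only applies when the caller omits the argument, whereas `personalization_config or ...` also caught an explicit `None` (e.g. a legacy config whose missing section is passed through as `None`). If that path matters, an `is None` sentinel preserves both behaviours; a sketch, using a hypothetical stand-in class:

```python
class LatentDiffusionSketch:  # hypothetical stand-in for LatentDiffusion
    @staticmethod
    def _fallback_personalization_config() -> dict:
        return {'target': 'ldm.modules.embedding_manager.EmbeddingManager',
                'params': {}}

    def __init__(self, personalization_config=None):
        # `is None` sentinel: covers both an omitted argument and an
        # explicit None, on every supported Python version.
        if personalization_config is None:
            personalization_config = self._fallback_personalization_config()
        self.personalization_config = personalization_config
```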
@@ -802,24 +820,6 @@ def instantiate_embedding_manager(self, config, embedder):
 
         return model
 
-    def _fallback_personalization_config(self) -> dict:
-        """
-        This protects us against custom legacy config files that
-        don't contain the personalization_config section.
-        """
-        return OmegaConf.create(
-            dict(
-                target='ldm.modules.embedding_manager.EmbeddingManager',
-                params=dict(
-                    placeholder_strings=list('*'),
-                    initializer_words=list('sculpture'),
-                    per_image_tokens=False,
-                    num_vectors_per_token=1,
-                    progressive_words=False,
-                )
-            )
-        )
-
     def _get_denoise_row_from_list(
         self, samples, desc='', force_no_decoder_quantization=False
     ):
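For context on what the fallback config feeds into: the `target`/`params` layout follows the `instantiate_from_config` convention used throughout the latent-diffusion lineage, where `target` is a dotted import path and `params` become constructor keyword arguments. The sketch below shows roughly how such a config is resolved; the exact helper signature in this repo is an assumption, not a quote. Note in passing that `list('sculpture')` evaluates to the single characters `['s', 'c', 'u', ...]` rather than `['sculpture']`; that expression is carried over unchanged from the removed method.

```python
# Rough sketch of the target/params convention (modeled on ldm.util in the
# CompVis latent-diffusion lineage; treat the extra-kwargs signature as an
# assumption rather than this repo's exact API).
import importlib

def instantiate_from_config(config, **extra_kwargs):
    # 'target' is a dotted path, e.g. 'ldm.modules.embedding_manager.EmbeddingManager'
    module_path, class_name = config['target'].rsplit('.', 1)
    cls = getattr(importlib.import_module(module_path), class_name)
    # 'params' supplies constructor keyword arguments
    return cls(**dict(config.get('params', {})), **extra_kwargs)

# Presumed shape of the call site from the second hunk (assumption):
# self.embedding_manager = instantiate_from_config(
#     personalization_config, embedder=self.cond_stage_model
# )
```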