@@ -135,7 +135,7 @@ def log_validation(
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
         text_encoder=accelerator.unwrap_model(text_encoder_1),
-        text_encoder_2=text_encoder_2,
+        text_encoder_2=accelerator.unwrap_model(text_encoder_2),
         tokenizer=tokenizer_1,
         tokenizer_2=tokenizer_2,
         unet=unet,
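Note for readers: `text_encoder_2` has to be unwrapped here because, later in this PR, it is passed through `accelerator.prepare()` for training. A minimal sketch of that pattern (variable names are illustrative, not from the script):

```python
# Hypothetical sketch: a module returned by accelerator.prepare() may be wrapped
# (e.g. in DistributedDataParallel), so recover the plain nn.Module before
# handing it to DiffusionPipeline.
wrapped_te2 = accelerator.prepare(text_encoder_2)       # wrapped for training
plain_te2 = accelerator.unwrap_model(wrapped_te2)       # original CLIP text model again
```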
@@ -678,36 +678,54 @@ def main():
678678 f"The tokenizer already contains the token { args .placeholder_token } . Please pass a different"
679679 " `placeholder_token` that is not already in the tokenizer."
680680 )
681+ num_added_tokens = tokenizer_2 .add_tokens (placeholder_tokens )
682+ if num_added_tokens != args .num_vectors :
683+ raise ValueError (
684+ f"The 2nd tokenizer already contains the token { args .placeholder_token } . Please pass a different"
685+ " `placeholder_token` that is not already in the tokenizer."
686+ )
681687
682688 # Convert the initializer_token, placeholder_token to ids
683689 token_ids = tokenizer_1 .encode (args .initializer_token , add_special_tokens = False )
690+ token_ids_2 = tokenizer_2 .encode (args .initializer_token , add_special_tokens = False )
691+
684692 # Check if initializer_token is a single token or a sequence of tokens
685- if len (token_ids ) > 1 :
693+ if len (token_ids ) > 1 or len ( token_ids_2 ) > 1 :
686694 raise ValueError ("The initializer token must be a single token." )
687695
688696 initializer_token_id = token_ids [0 ]
689697 placeholder_token_ids = tokenizer_1 .convert_tokens_to_ids (placeholder_tokens )
698+ initializer_token_id_2 = token_ids_2 [0 ]
699+ placeholder_token_ids_2 = tokenizer_2 .convert_tokens_to_ids (placeholder_tokens )
690700
691701 # Resize the token embeddings as we are adding new special tokens to the tokenizer
692702 text_encoder_1 .resize_token_embeddings (len (tokenizer_1 ))
703+ text_encoder_2 .resize_token_embeddings (len (tokenizer_2 ))
693704
694705 # Initialise the newly added placeholder token with the embeddings of the initializer token
695706 token_embeds = text_encoder_1 .get_input_embeddings ().weight .data
707+ token_embeds_2 = text_encoder_2 .get_input_embeddings ().weight .data
696708 with torch .no_grad ():
697709 for token_id in placeholder_token_ids :
698710 token_embeds [token_id ] = token_embeds [initializer_token_id ].clone ()
711+ for token_id in placeholder_token_ids_2 :
712+ token_embeds_2 [token_id ] = token_embeds_2 [initializer_token_id_2 ].clone ()
699713
700714 # Freeze vae and unet
701715 vae .requires_grad_ (False )
702716 unet .requires_grad_ (False )
703- text_encoder_2 . requires_grad_ ( False )
717+
704718 # Freeze all parameters except for the token embeddings in text encoder
705719 text_encoder_1 .text_model .encoder .requires_grad_ (False )
706720 text_encoder_1 .text_model .final_layer_norm .requires_grad_ (False )
707721 text_encoder_1 .text_model .embeddings .position_embedding .requires_grad_ (False )
722+ text_encoder_2 .text_model .encoder .requires_grad_ (False )
723+ text_encoder_2 .text_model .final_layer_norm .requires_grad_ (False )
724+ text_encoder_2 .text_model .embeddings .position_embedding .requires_grad_ (False )
708725
709726 if args .gradient_checkpointing :
710727 text_encoder_1 .gradient_checkpointing_enable ()
728+ text_encoder_2 .gradient_checkpointing_enable ()
711729
712730 if args .enable_xformers_memory_efficient_attention :
713731 if is_xformers_available ():
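The hunk above repeats the same three-step token setup for the second tokenizer/encoder pair. A condensed sketch of that recipe (the `add_placeholder` helper is hypothetical, not part of the script):

```python
# Hypothetical helper (not in the script) showing the recipe the hunk above
# applies to the second tokenizer/encoder pair.
import torch


def add_placeholder(tokenizer, text_encoder, placeholder_tokens, initializer_token):
    # 1. register the new tokens; fail if any of them already exist
    added = tokenizer.add_tokens(placeholder_tokens)
    if added != len(placeholder_tokens):
        raise ValueError("placeholder token already present in the tokenizer")

    # 2. grow the embedding matrix so the new ids get rows
    text_encoder.resize_token_embeddings(len(tokenizer))

    # 3. start the new rows from the initializer token's embedding
    init_id = tokenizer.encode(initializer_token, add_special_tokens=False)[0]
    new_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
    embeds = text_encoder.get_input_embeddings().weight.data
    with torch.no_grad():
        for token_id in new_ids:
            embeds[token_id] = embeds[init_id].clone()
    return new_ids
```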
@@ -746,7 +764,11 @@ def main():
         optimizer_class = torch.optim.AdamW

     optimizer = optimizer_class(
-        text_encoder_1.get_input_embeddings().parameters(),  # only optimize the embeddings
+        # only optimize the embeddings
+        [
+            text_encoder_1.text_model.embeddings.token_embedding.weight,
+            text_encoder_2.text_model.embeddings.token_embedding.weight,
+        ],
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
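Why the parameter group changes: `text_encoder_1.get_input_embeddings().parameters()` covers only the first encoder, so the PR hands the token-embedding weights of both encoders to a single optimizer. An optional sanity check within the script's scope (not part of the PR):

```python
# Optional sanity check (not in the PR): confirm that exactly the two token
# embedding matrices are trainable and are what the optimizer receives.
import torch

trainable = [
    text_encoder_1.text_model.embeddings.token_embedding.weight,
    text_encoder_2.text_model.embeddings.token_embedding.weight,
]
assert all(p.requires_grad for p in trainable)
optimizer = torch.optim.AdamW(trainable, lr=args.learning_rate)
```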
@@ -786,9 +808,10 @@ def main():
     )

     text_encoder_1.train()
+    text_encoder_2.train()
     # Prepare everything with our `accelerator`.
-    text_encoder_1, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder_1, optimizer, train_dataloader, lr_scheduler
+    text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler
     )

     # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
@@ -866,11 +889,13 @@ def main():

     # keep original embeddings as reference
     orig_embeds_params = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()
+    orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()

     for epoch in range(first_epoch, args.num_train_epochs):
         text_encoder_1.train()
+        text_encoder_2.train()
         for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(text_encoder_1):
+            with accelerator.accumulate([text_encoder_1, text_encoder_2]):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                 latents = latents * vae.config.scaling_factor
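Listing both encoders in `accelerator.accumulate(...)` makes Accelerate defer gradient synchronization for both models on non-sync steps. A rough sketch of the loop shape this sits in (`compute_loss` is a hypothetical stand-in for the noise-prediction loss computed in the script):

```python
# Rough sketch only; compute_loss is hypothetical, and the list form of
# accumulate() mirrors the diff above.
for step, batch in enumerate(train_dataloader):
    with accelerator.accumulate([text_encoder_1, text_encoder_2]):
        loss = compute_loss(batch)      # noise-prediction MSE in the real script
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
```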
@@ -892,9 +917,7 @@ def main():
                     .hidden_states[-2]
                     .to(dtype=weight_dtype)
                 )
-                encoder_output_2 = text_encoder_2(
-                    batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True
-                )
+                encoder_output_2 = text_encoder_2(batch["input_ids_2"], output_hidden_states=True)
                 encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
                 original_size = [
                     (batch["original_size"][0][i].item(), batch["original_size"][1][i].item())
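For context, SDXL conditions on both encoders: the penultimate hidden states drive cross-attention, while the pooled output of the second encoder feeds the added conditioning. A hedged sketch of that assembly (variable names such as `add_time_ids` and `noisy_latents` are assumed from the surrounding script and are not shown in this hunk):

```python
# Hedged sketch of how the two encoder outputs are consumed downstream
# (assumed from the rest of the script, not visible in this hunk).
encoder_hidden_states = torch.cat([encoder_hidden_states_1, encoder_hidden_states_2], dim=-1)
added_cond_kwargs = {"text_embeds": encoder_output_2[0], "time_ids": add_time_ids}
model_pred = unet(
    noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
).sample
```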
@@ -938,11 +961,16 @@ def main():
                 # Let's make sure we don't update any embedding weights besides the newly added token
                 index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool)
                 index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
+                index_no_updates_2 = torch.ones((len(tokenizer_2),), dtype=torch.bool)
+                index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False

                 with torch.no_grad():
                     accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
                         index_no_updates
                     ] = orig_embeds_params[index_no_updates]
+                    accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
+                        index_no_updates_2
+                    ] = orig_embeds_params_2[index_no_updates_2]

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
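Why the restore step exists: the optimizer holds each embedding matrix as a single parameter, so rows other than the placeholder rows can still drift (for example through weight decay). Copying the original values back after every step keeps the rest of the vocabulary frozen. A one-encoder sketch of the trick with generic names (the diff applies it to both encoders):

```python
# Generic sketch: keep every row except the placeholder rows fixed by copying
# the original values back after the optimizer step.
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
index_no_updates[placeholder_token_ids] = False    # only these rows may change
with torch.no_grad():
    embeds = text_encoder.get_input_embeddings().weight
    embeds[index_no_updates] = orig_embeds_params[index_no_updates]
```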
@@ -960,6 +988,16 @@ def main():
                         save_path,
                         safe_serialization=True,
                     )
+                    weight_name = f"learned_embeds_2-steps-{global_step}.safetensors"
+                    save_path = os.path.join(args.output_dir, weight_name)
+                    save_progress(
+                        text_encoder_2,
+                        placeholder_token_ids_2,
+                        accelerator,
+                        args,
+                        save_path,
+                        safe_serialization=True,
+                    )

                 if accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
@@ -1034,7 +1072,7 @@ def main():
             pipeline = DiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 text_encoder=accelerator.unwrap_model(text_encoder_1),
-                text_encoder_2=text_encoder_2,
+                text_encoder_2=accelerator.unwrap_model(text_encoder_2),
                 vae=vae,
                 unet=unet,
                 tokenizer=tokenizer_1,
@@ -1052,6 +1090,16 @@ def main():
             save_path,
             safe_serialization=True,
         )
+        weight_name = "learned_embeds_2.safetensors"
+        save_path = os.path.join(args.output_dir, weight_name)
+        save_progress(
+            text_encoder_2,
+            placeholder_token_ids_2,
+            accelerator,
+            args,
+            save_path,
+            safe_serialization=True,
+        )

         if args.push_to_hub:
             save_model_card(
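Once training finishes, the two embedding files can be loaded back into an SDXL pipeline, one per encoder. A hedged usage sketch (output paths and the placeholder token are illustrative; the call follows diffusers' `load_textual_inversion`, which accepts `tokenizer=` / `text_encoder=` overrides):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Paths and token are illustrative; they correspond to the files saved above.
pipe.load_textual_inversion(
    "output_dir/learned_embeds.safetensors", token="<my-token>",
    text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer,
)
pipe.load_textual_inversion(
    "output_dir/learned_embeds_2.safetensors", token="<my-token>",
    text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2,
)
image = pipe("a photo of <my-token>").images[0]
```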