@@ -880,9 +880,7 @@ def save_embeddings(self, file_path: str):
880880 idx_to_text_encoder_name = {0 : "clip_l" , 1 : "t5" }
881881 for idx , text_encoder in enumerate (self .text_encoders ):
882882 train_ids = self .train_ids if idx == 0 else self .train_ids_t5
883- embeds = (
884- text_encoder .text_model .embeddings .token_embedding if idx == 0 else text_encoder .encoder .embed_tokens
885- )
883+ embeds = text_encoder .text_model .embeddings .token_embedding if idx == 0 else text_encoder .shared
886884 assert embeds .weight .data .shape [0 ] == len (self .tokenizers [idx ]), "Tokenizers should be the same."
887885 new_token_embeddings = embeds .weight .data [train_ids ]
888886
@@ -904,9 +902,7 @@ def device(self):
904902 @torch .no_grad ()
905903 def retract_embeddings (self ):
906904 for idx , text_encoder in enumerate (self .text_encoders ):
907- embeds = (
908- text_encoder .text_model .embeddings .token_embedding if idx == 0 else text_encoder .encoder .embed_tokens
909- )
905+ embeds = text_encoder .text_model .embeddings .token_embedding if idx == 0 else text_encoder .shared
910906 index_no_updates = self .embeddings_settings [f"index_no_updates_{ idx } " ]
911907 embeds .weight .data [index_no_updates ] = (
912908 self .embeddings_settings [f"original_embeddings_{ idx } " ][index_no_updates ]
@@ -1749,7 +1745,7 @@ def load_model_hook(models, input_dir):
17491745 if args .enable_t5_ti : # whether to do pivotal tuning/textual inversion for T5 as well
17501746 text_lora_parameters_two = []
17511747 for name , param in text_encoder_two .named_parameters ():
1752- if "token_embedding " in name :
1748+ if "shared " in name :
17531749 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
17541750 param .data = param .to (dtype = torch .float32 )
17551751 param .requires_grad = True
0 commit comments