@@ -881,7 +881,7 @@ def save_embeddings(self, file_path: str):
         for idx, text_encoder in enumerate(self.text_encoders):
             train_ids = self.train_ids if idx == 0 else self.train_ids_t5
             embeds = (
-                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
+                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
             )
             assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
             new_token_embeddings = embeds.weight.data[train_ids]
@@ -905,7 +905,7 @@ def device(self):
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
             embeds = (
-                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
+                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
             )
             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
             embeds.weight.data[index_no_updates] = (
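Context for the change above: in the `transformers` T5 implementation the input token embedding lives on the model as `shared`, and the encoder stack's `embed_tokens` is tied to that same weight, so indexing `text_encoder.shared.weight` reaches the same matrix as the old `text_encoder.encoder.embed_tokens` access. A minimal sketch of that relationship, assuming a T5 checkpoint such as `google/t5-v1_1-small` (chosen only for illustration):

```python
import torch
from transformers import T5EncoderModel

# Illustrative checkpoint; the training script loads whichever T5 encoder it actually uses.
text_encoder_two = T5EncoderModel.from_pretrained("google/t5-v1_1-small")

# `shared` is the input token embedding; it is what get_input_embeddings() returns.
assert text_encoder_two.get_input_embeddings() is text_encoder_two.shared
# The encoder stack's embed_tokens is tied to the same weight matrix.
assert text_encoder_two.encoder.embed_tokens.weight is text_encoder_two.shared.weight

# Reading a few embedding rows the way save_embeddings()/retract_embeddings() index them.
example_train_ids_t5 = torch.tensor([5, 6])  # placeholder ids, not the script's inserted tokens
rows = text_encoder_two.shared.weight.data[example_train_ids_t5]
print(rows.shape)  # torch.Size([2, d_model])
```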
@@ -1749,7 +1749,7 @@ def load_model_hook(models, input_dir):
         if args.enable_t5_ti:  # whether to do pivotal tuning/textual inversion for T5 as well
             text_lora_parameters_two = []
             for name, param in text_encoder_two.named_parameters():
-                if "token_embedding" in name:
+                if "shared" in name:
                     # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
                     param.data = param.to(dtype=torch.float32)
                     param.requires_grad = True
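The filter change follows from the same layout: CLIP's token embedding parameter name contains `token_embedding`, but T5 registers its embedding as `shared.weight`, and `named_parameters()` reports the tied weight only once under that name. A small check, again assuming the illustrative `google/t5-v1_1-small` checkpoint from the sketch above:

```python
from transformers import T5EncoderModel

text_encoder_two = T5EncoderModel.from_pretrained("google/t5-v1_1-small")  # illustrative checkpoint

# The tied input embedding shows up exactly once, as "shared.weight".
matches = [name for name, _ in text_encoder_two.named_parameters() if "shared" in name]
print(matches)  # expected: ['shared.weight']

# The previous CLIP-oriented filter matches nothing on a T5 encoder,
# which is why no T5 embedding parameters were unfrozen before this change.
assert not any("token_embedding" in name for name, _ in text_encoder_two.named_parameters())
```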