@@ -810,6 +810,7 @@ def __init__(self, text_encoders, tokenizers):
         self.tokenizers = tokenizers
 
         self.train_ids: Optional[torch.Tensor] = None
+        self.train_ids_t5: Optional[torch.Tensor] = None
         self.inserting_toks: Optional[List[str]] = None
         self.embeddings_settings = {}
 
@@ -828,7 +829,10 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
             text_encoder.resize_token_embeddings(len(tokenizer))
 
             # Convert the token abstractions to ids
-            self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+            if idx == 0:
+                self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+            else:
+                self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks)
 
             # random initialization of new tokens
             embeds = (
@@ -838,19 +842,20 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
 
             logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
 
+            train_ids = self.train_ids if idx == 0 else self.train_ids_t5
             # if initializer_concept are not provided, token embeddings are initialized randomly
             if args.initializer_concept is None:
                 hidden_size = (
                     text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
                 )
-                embeds.weight.data[self.train_ids] = (
-                    torch.randn(len(self.train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype)
+                embeds.weight.data[train_ids] = (
+                    torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype)
                     * std_token_embedding
                 )
             else:
                 # Convert the initializer_token, placeholder_token to ids
                 initializer_token_ids = tokenizer.encode(args.initializer_concept, add_special_tokens=False)
-                for token_idx, token_id in enumerate(self.train_ids):
+                for token_idx, token_id in enumerate(train_ids):
                     embeds.weight.data[token_id] = (embeds.weight.data)[
                         initializer_token_ids[token_idx % len(initializer_token_ids)]
                     ].clone()
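For readers skimming the hunk above, here is a minimal, self-contained sketch of the two initialization branches it touches: scaled random init when no initializer concept is given, otherwise copying (and cycling through) the initializer's embeddings. The embedding size and token ids below are made up for illustration and are not taken from the script.

```python
import torch

# Illustrative only: a tiny embedding table standing in for the text encoder's
# token embeddings after resize_token_embeddings added two new rows.
hidden_size = 8
embeds = torch.nn.Embedding(12, hidden_size)
train_ids = [10, 11]  # hypothetical ids of the newly added tokens
std_token_embedding = embeds.weight.data.std()

# e.g. tokenizer.encode(args.initializer_concept, add_special_tokens=False)
initializer_token_ids = None

if initializer_token_ids is None:
    # Random init, scaled so new rows have a spread similar to existing rows.
    embeds.weight.data[train_ids] = torch.randn(len(train_ids), hidden_size) * std_token_embedding
else:
    # Copy the initializer concept's embeddings, cycling through them if the
    # concept has fewer tokens than the number of new tokens.
    for token_idx, token_id in enumerate(train_ids):
        embeds.weight.data[token_id] = embeds.weight.data[
            initializer_token_ids[token_idx % len(initializer_token_ids)]
        ].clone()
```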
@@ -860,7 +865,7 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
 
             # makes sure we don't update any embedding weights besides the newly added token
             index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
-            index_no_updates[self.train_ids] = False
+            index_no_updates[train_ids] = False
 
             self.embeddings_settings[f"index_no_updates_{idx}"] = index_no_updates
 
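The `index_no_updates` mask built above is what keeps pivotal tuning from drifting the rest of the vocabulary. Below is a minimal sketch of how such a mask is typically applied after an optimizer step; the tensor names, sizes, and the restore step are illustrative assumptions, not the script's own code.

```python
import torch

vocab_size = 12
train_ids = [10, 11]  # hypothetical ids of the newly added tokens

# True for every embedding row that must stay frozen.
index_no_updates = torch.ones((vocab_size,), dtype=torch.bool)
index_no_updates[train_ids] = False

embeds = torch.nn.Embedding(vocab_size, 8)
original_embeddings = embeds.weight.data.clone()

# ... an optimizer step may have touched every row of embeds.weight ...

# Restore the frozen rows so only the new-token rows actually change.
with torch.no_grad():
    embeds.weight.data[index_no_updates] = original_embeddings[index_no_updates]
```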
@@ -874,11 +879,12 @@ def save_embeddings(self, file_path: str):
         # text_encoder_one, idx==0 - CLIP ViT-L/14, text_encoder_two, idx==1 - T5 xxl
         idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
         for idx, text_encoder in enumerate(self.text_encoders):
+            train_ids = self.train_ids if idx == 0 else self.train_ids_t5
             embeds = (
                 text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
             )
             assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
-            new_token_embeddings = embeds.weight.data[self.train_ids]
+            new_token_embeddings = embeds.weight.data[train_ids]
 
             # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0),
             # Note: When loading with diffusers, any name can work - simply specify in inference
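With a per-encoder `train_ids` in place, `save_embeddings` can gather one tensor of new-token embeddings per text encoder. As a hypothetical sketch, such a dictionary is usually serialized with safetensors as shown below; the key names follow the `idx_to_text_encoder_name` mapping above, while the shapes and file name are made up for illustration.

```python
import torch
from safetensors.torch import save_file

# Stand-ins for embeds.weight.data[train_ids] gathered from each encoder.
new_clip_embeddings = torch.randn(2, 768)   # CLIP ViT-L/14 hidden size
new_t5_embeddings = torch.randn(2, 4096)    # T5-XXL hidden size

tensors = {"clip_l": new_clip_embeddings, "t5": new_t5_embeddings}
save_file(tensors, "new_token_embeddings.safetensors")
```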