Commit 69d28b5

fix indices of t5 pivotal tuning embeddings
1 parent bd2be32 commit 69d28b5

File tree

1 file changed: +12 −6 lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 12 additions & 6 deletions
@@ -810,6 +810,7 @@ def __init__(self, text_encoders, tokenizers):
         self.tokenizers = tokenizers
 
         self.train_ids: Optional[torch.Tensor] = None
+        self.train_ids_t5: Optional[torch.Tensor] = None
         self.inserting_toks: Optional[List[str]] = None
         self.embeddings_settings = {}
 
@@ -828,7 +829,10 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
             text_encoder.resize_token_embeddings(len(tokenizer))
 
             # Convert the token abstractions to ids
-            self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+            if idx == 0:
+                self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+            else:
+                self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks)
 
             # random initialization of new tokens
             embeds = (
@@ -838,19 +842,20 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
 
             logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
 
+            train_ids = self.train_ids if idx == 0 else self.train_ids_t5
             # if initializer_concept are not provided, token embeddings are initialized randomly
             if args.initializer_concept is None:
                 hidden_size = (
                     text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
                 )
-                embeds.weight.data[self.train_ids] = (
-                    torch.randn(len(self.train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype)
+                embeds.weight.data[train_ids] = (
+                    torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype)
                     * std_token_embedding
                 )
             else:
                 # Convert the initializer_token, placeholder_token to ids
                 initializer_token_ids = tokenizer.encode(args.initializer_concept, add_special_tokens=False)
-                for token_idx, token_id in enumerate(self.train_ids):
+                for token_idx, token_id in enumerate(train_ids):
                     embeds.weight.data[token_id] = (embeds.weight.data)[
                         initializer_token_ids[token_idx % len(initializer_token_ids)]
                     ].clone()
@@ -860,7 +865,7 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
 
             # makes sure we don't update any embedding weights besides the newly added token
             index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
-            index_no_updates[self.train_ids] = False
+            index_no_updates[train_ids] = False
 
             self.embeddings_settings[f"index_no_updates_{idx}"] = index_no_updates
 
@@ -874,11 +879,12 @@ def save_embeddings(self, file_path: str):
         # text_encoder_one, idx==0 - CLIP ViT-L/14, text_encoder_two, idx==1 - T5 xxl
         idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
         for idx, text_encoder in enumerate(self.text_encoders):
+            train_ids = self.train_ids if idx == 0 else self.train_ids_t5
             embeds = (
                 text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
            )
             assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
-            new_token_embeddings = embeds.weight.data[self.train_ids]
+            new_token_embeddings = embeds.weight.data[train_ids]
 
             # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0),
             # Note: When loading with diffusers, any name can work - simply specify in inference
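For context (not part of the commit): the CLIP and T5 tokenizers have different vocabularies, so the same newly inserted tokens receive different ids in each. Before this fix, `self.train_ids` was overwritten on each loop iteration and ended up holding the T5 ids, which `save_embeddings` then reused to slice the CLIP embedding table as well, picking the wrong rows. The sketch below illustrates the id mismatch; the checkpoint names and token strings are illustrative stand-ins, not the ones the training script actually loads.

```python
# Minimal sketch (assumes stand-in tokenizer checkpoints and example token names)
# showing that the same inserted tokens get different ids per tokenizer,
# which is why separate train_ids / train_ids_t5 lists are needed.
from transformers import AutoTokenizer, CLIPTokenizer

new_tokens = ["<s0>", "<s1>"]  # hypothetical pivotal-tuning token abstractions

clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
t5_tok = AutoTokenizer.from_pretrained("google/t5-v1_1-xxl")

for tok in (clip_tok, t5_tok):
    # mirrors initialize_new_tokens: register the new tokens as special tokens
    tok.add_special_tokens({"additional_special_tokens": new_tokens})

clip_ids = clip_tok.convert_tokens_to_ids(new_tokens)  # appended after CLIP's ~49k vocab
t5_ids = t5_tok.convert_tokens_to_ids(new_tokens)      # appended after T5's ~32k vocab

# The id lists differ, so indexing one encoder's embedding matrix with the
# other tokenizer's ids would read or write unrelated rows.
assert clip_ids != t5_ids
print("clip_l ids:", clip_ids)
print("t5 ids:", t5_ids)
```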
