Skip to content

Commit 25f5e06

Browse files
committed
fix t5 training bug
1 parent a4c1aac commit 25f5e06

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,7 @@ def save_embeddings(self, file_path: str):
881881
for idx, text_encoder in enumerate(self.text_encoders):
882882
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
883883
embeds = (
884-
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
884+
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
885885
)
886886
assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
887887
new_token_embeddings = embeds.weight.data[train_ids]
@@ -905,7 +905,7 @@ def device(self):
905905
def retract_embeddings(self):
906906
for idx, text_encoder in enumerate(self.text_encoders):
907907
embeds = (
908-
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
908+
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
909909
)
910910
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
911911
embeds.weight.data[index_no_updates] = (
@@ -1749,7 +1749,7 @@ def load_model_hook(models, input_dir):
17491749
if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well
17501750
text_lora_parameters_two = []
17511751
for name, param in text_encoder_two.named_parameters():
1752-
if "token_embedding" in name:
1752+
if "shared" in name:
17531753
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
17541754
param.data = param.to(dtype=torch.float32)
17551755
param.requires_grad = True

0 commit comments

Comments
 (0)