Commit e031caf

[flux lora training] fix t5 training bug (huggingface#10845)
* fix t5 training bug
* Apply style fixes

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 08f74a8 commit e031caf

File tree

1 file changed: +3 -7 lines changed


examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 3 additions & 7 deletions
```diff
@@ -880,9 +880,7 @@ def save_embeddings(self, file_path: str):
         idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
         for idx, text_encoder in enumerate(self.text_encoders):
             train_ids = self.train_ids if idx == 0 else self.train_ids_t5
-            embeds = (
-                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
-            )
+            embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
             assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
             new_token_embeddings = embeds.weight.data[train_ids]
```

```diff
@@ -904,9 +902,7 @@ def device(self):
     @torch.no_grad()
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
-            embeds = (
-                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
-            )
+            embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
             embeds.weight.data[index_no_updates] = (
                 self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
```
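
Both hunks above repoint `embeds` at the module that actually owns the T5 embedding table. In `transformers`, `T5EncoderModel` registers the token embeddings on the top-level `shared` attribute and ties the encoder's `embed_tokens` to that same module. A minimal sketch of the relationship, assuming `transformers` is installed and using `google/t5-v1_1-small` purely as an illustrative checkpoint (the script loads whatever T5 encoder the Flux pipeline ships with):

```python
from transformers import T5EncoderModel

# Illustrative checkpoint only; the training script loads the Flux
# pipeline's own T5 text encoder.
text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-small")

# The token-embedding table is registered as `shared`; the encoder's
# `embed_tokens` is tied to that same module object.
assert text_encoder.encoder.embed_tokens is text_encoder.shared

embeds = text_encoder.shared  # the access path used by the fixed lines
print(embeds.weight.data.shape)  # e.g. torch.Size([32128, 512]) here
```

Reading the weights through `shared` keeps `save_embeddings` and `retract_embeddings` consistent with the parameter-name filter fixed in the next hunk.
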
```diff
@@ -1749,7 +1745,7 @@ def load_model_hook(models, input_dir):
         if args.enable_t5_ti:  # whether to do pivotal tuning/textual inversion for T5 as well
             text_lora_parameters_two = []
             for name, param in text_encoder_two.named_parameters():
-                if "token_embedding" in name:
+                if "shared" in name:
                     # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
                     param.data = param.to(dtype=torch.float32)
                     param.requires_grad = True
```
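
The filter change above is where the T5 training bug actually bit: because `shared` and `encoder.embed_tokens` are one tied parameter, `named_parameters()` de-duplicates them and reports the embedding table exactly once, under the name `shared.weight`. No T5 parameter name contains `token_embedding`, so the old condition never matched, and the T5 embeddings were never cast to float32 or marked trainable. A quick check, again assuming the illustrative `google/t5-v1_1-small` checkpoint:

```python
from transformers import T5EncoderModel

text_encoder_two = T5EncoderModel.from_pretrained("google/t5-v1_1-small")
names = [name for name, _ in text_encoder_two.named_parameters()]

# Old filter: matches nothing, so the training flags were never set.
print([n for n in names if "token_embedding" in n])  # []

# Fixed filter: matches the tied, de-duplicated embedding table once.
print([n for n in names if "shared" in n])  # ['shared.weight']
```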
