Skip to content

Commit 8c1751e

Browse files
Apply style fixes
1 parent 5c47abf commit 8c1751e

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -880,9 +880,7 @@ def save_embeddings(self, file_path: str):
880880
idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
881881
for idx, text_encoder in enumerate(self.text_encoders):
882882
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
883-
embeds = (
884-
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
885-
)
883+
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
886884
assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
887885
new_token_embeddings = embeds.weight.data[train_ids]
888886

@@ -904,9 +902,7 @@ def device(self):
904902
@torch.no_grad()
905903
def retract_embeddings(self):
906904
for idx, text_encoder in enumerate(self.text_encoders):
907-
embeds = (
908-
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
909-
)
905+
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
910906
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
911907
embeds.weight.data[index_no_updates] = (
912908
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]

0 commit comments

Comments
 (0)