diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 51b96ec72f10..7cb0d666fe69 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -880,9 +880,7 @@ def save_embeddings(self, file_path: str):
         idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
         for idx, text_encoder in enumerate(self.text_encoders):
             train_ids = self.train_ids if idx == 0 else self.train_ids_t5
-            embeds = (
-                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
-            )
+            embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
             assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
             new_token_embeddings = embeds.weight.data[train_ids]
 
@@ -904,9 +902,7 @@ def device(self):
     @torch.no_grad()
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
-            embeds = (
-                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
-            )
+            embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
             embeds.weight.data[index_no_updates] = (
                 self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
@@ -1749,7 +1745,7 @@ def load_model_hook(models, input_dir):
         if args.enable_t5_ti:  # whether to do pivotal tuning/textual inversion for T5 as well
             text_lora_parameters_two = []
             for name, param in text_encoder_two.named_parameters():
-                if "token_embedding" in name:
+                if "shared" in name:
                     # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
                     param.data = param.to(dtype=torch.float32)
                     param.requires_grad = True
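
Note (not part of the diff): in current `transformers` releases, `T5EncoderModel` — the class Flux uses for its second text encoder — registers its token embedding table as `shared`, with `encoder.embed_tokens` being a tied reference to the same module, and `named_parameters()` deduplicates tied weights so the embedding only shows up under the name `shared.weight`. That is why the attribute access and the substring check both switch to `shared` above. The sketch below illustrates this; the small `google/t5-v1_1-small` checkpoint is chosen purely for illustration and the printed values reflect my understanding of the T5 implementation, not output from this PR.

```python
# Minimal sketch (assumption: transformers' T5EncoderModel, any T5 checkpoint).
from transformers import T5EncoderModel

t5 = T5EncoderModel.from_pretrained("google/t5-v1_1-small")  # small checkpoint, illustration only

# `shared` is the input embedding table; `encoder.embed_tokens` is tied to it.
print(t5.get_input_embeddings() is t5.shared)              # expected: True
print(t5.shared.weight is t5.encoder.embed_tokens.weight)  # expected: True (tied tensor)

# named_parameters() deduplicates tied weights, so the table is reported as
# "shared.weight" only -- a substring check for "token_embedding" never matches
# on T5, which is what the `if "shared" in name` change addresses.
print([n for n, _ in t5.named_parameters() if "shared" in n])  # expected: ['shared.weight']
```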