@@ -881,7 +881,7 @@ def save_embeddings(self, file_path: str):
 for idx, text_encoder in enumerate(self.text_encoders):
     train_ids = self.train_ids if idx == 0 else self.train_ids_t5
     embeds = (
-        text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
+        text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
     )
     assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
     new_token_embeddings = embeds.weight.data[train_ids]
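
For the T5 text encoder (idx == 1), the embedding is now read through `text_encoder.shared` instead of `text_encoder.encoder.embed_tokens`. In transformers' `T5EncoderModel`, `shared` is the module returned by `get_input_embeddings()`, and `encoder.embed_tokens` shares its weights (in some versions it is literally the same module), so `shared` is the more direct handle when saving embedding rows. A minimal sketch of the lookup this hunk performs, assuming a stock `T5EncoderModel`; the checkpoint name and token ids are placeholders:

```python
# Sketch only (not the training script): how the T5 token embedding is reached
# and how rows for newly added tokens would be sliced out, as in save_embeddings.
import torch
from transformers import T5EncoderModel

text_encoder_two = T5EncoderModel.from_pretrained("google/t5-v1_1-small")  # placeholder checkpoint

embeds = text_encoder_two.shared                       # nn.Embedding(vocab_size, d_model)
assert embeds is text_encoder_two.get_input_embeddings()

# `encoder.embed_tokens` shares storage with `shared` (same module or tied weights,
# depending on the transformers version), so both paths see the same rows.
assert embeds.weight.data_ptr() == text_encoder_two.encoder.embed_tokens.weight.data_ptr()

train_ids_t5 = torch.tensor([32100, 32101])            # hypothetical ids of inserted tokens
new_token_embeddings = embeds.weight.data[train_ids_t5]
print(new_token_embeddings.shape)                      # torch.Size([2, 512]) for this checkpoint
```
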
@@ -905,7 +905,7 @@ def device(self):
 def retract_embeddings(self):
     for idx, text_encoder in enumerate(self.text_encoders):
         embeds = (
-            text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
+            text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
         )
         index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
         embeds.weight.data[index_no_updates] = (
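
`retract_embeddings` uses the same `shared` handle for the T5 branch. The idea, simplified here, is that after each optimizer step every embedding row except the newly added tokens' rows is restored from a saved copy, so only the pivotal tokens actually train. A rough sketch of that pattern under simplified assumptions; the real method's right-hand side (drawn from `self.embeddings_settings`) may also re-cast or re-scale the restored rows:

```python
# Sketch of the "retract" pattern under simplified assumptions: restore original
# rows for every token that is NOT being trained, so only new tokens' rows move.
import torch
from transformers import T5EncoderModel

text_encoder_two = T5EncoderModel.from_pretrained("google/t5-v1_1-small")  # placeholder checkpoint
embeds = text_encoder_two.shared

train_ids_t5 = torch.tensor([32100, 32101])            # hypothetical new token ids
index_no_updates = torch.ones(embeds.weight.shape[0], dtype=torch.bool)
index_no_updates[train_ids_t5] = False                 # True for every row we must not change

orig_embeddings = embeds.weight.data.clone()           # snapshot taken before training starts

# ... optimizer.step() on the embedding would happen here ...

with torch.no_grad():
    embeds.weight.data[index_no_updates] = orig_embeddings[index_no_updates]
```
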
@@ -1749,7 +1749,7 @@ def load_model_hook(models, input_dir):
 if args.enable_t5_ti:  # whether to do pivotal tuning/textual inversion for T5 as well
     text_lora_parameters_two = []
     for name, param in text_encoder_two.named_parameters():
-        if "token_embedding" in name:
+        if "shared" in name:
             # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
             param.data = param.to(dtype=torch.float32)
             param.requires_grad = True
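
The substring check is what selects which T5 parameters become trainable for textual inversion. In a `T5EncoderModel`, no parameter name contains "token_embedding" (that naming belongs to the CLIP text encoder's `text_model.embeddings.token_embedding`); the T5 token embedding matrix is reported as `shared.weight`, so the old condition matched nothing while the new one matches exactly that parameter. A quick check, with a placeholder checkpoint:

```python
# Sketch: compare which parameter names the old and new substring checks match
# in a T5 encoder.
from transformers import T5EncoderModel

text_encoder_two = T5EncoderModel.from_pretrained("google/t5-v1_1-small")  # placeholder checkpoint

matches_old = [n for n, _ in text_encoder_two.named_parameters() if "token_embedding" in n]
matches_new = [n for n, _ in text_encoder_two.named_parameters() if "shared" in n]

print(matches_old)  # [] -- no T5 parameter name contains "token_embedding"
print(matches_new)  # ['shared.weight'] -- the token embedding matrix
```
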