Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/diffusers/loaders/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ def unload_textual_inversion(
tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
key_id += 1
tokenizer._update_trie()
tokenizer._update_total_vocab_size()

# Delete from text encoder
text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
Expand Down
23 changes: 23 additions & 0 deletions tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,29 @@ def test_text_inversion_multi_tokens(self):
emb1[num_tokens + 1].sum().item() == emb2[num_tokens + 1].sum().item() == emb3[num_tokens + 1].sum().item()
)

def test_textual_inversion_unload(self):
pipe1 = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
)
pipe1 = pipe1.to(torch_device)
orig_tokenizer_size = len(pipe1.tokenizer)
orig_emb_size = len(pipe1.text_encoder.get_input_embeddings().weight)

token = "<*>"
ten = torch.ones((32,))
pipe1.load_textual_inversion(ten, token=token)
pipe1.unload_textual_inversion()
pipe1.load_textual_inversion(ten, token=token)
pipe1.unload_textual_inversion()


final_tokenizer_size = len(pipe1.tokenizer)
final_emb_size = len(pipe1.text_encoder.get_input_embeddings().weight)
# both should be restored to original size
assert final_tokenizer_size == orig_tokenizer_size
assert final_emb_size == orig_emb_size


def test_download_ignore_files(self):
# Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
with tempfile.TemporaryDirectory() as tmpdirname:
Expand Down
Loading