Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/diffusers/loaders/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,8 @@ def unload_textual_inversion(
tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
key_id += 1
tokenizer._update_trie()
# set correct total vocab size after removing tokens
tokenizer._update_total_vocab_size()

# Delete from text encoder
text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
Expand Down
21 changes: 21 additions & 0 deletions tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,27 @@ def test_text_inversion_multi_tokens(self):
emb1[num_tokens + 1].sum().item() == emb2[num_tokens + 1].sum().item() == emb3[num_tokens + 1].sum().item()
)

def test_textual_inversion_unload(self):
pipe1 = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
)
pipe1 = pipe1.to(torch_device)
orig_tokenizer_size = len(pipe1.tokenizer)
orig_emb_size = len(pipe1.text_encoder.get_input_embeddings().weight)

token = "<*>"
ten = torch.ones((32,))
pipe1.load_textual_inversion(ten, token=token)
pipe1.unload_textual_inversion()
pipe1.load_textual_inversion(ten, token=token)
pipe1.unload_textual_inversion()

final_tokenizer_size = len(pipe1.tokenizer)
final_emb_size = len(pipe1.text_encoder.get_input_embeddings().weight)
# both should be restored to original size
assert final_tokenizer_size == orig_tokenizer_size
assert final_emb_size == orig_emb_size

def test_download_ignore_files(self):
# Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
with tempfile.TemporaryDirectory() as tmpdirname:
Expand Down
Loading