Skip to content

Commit 5d3e7bd

Browse files
bonlime, sayakpaul, Your Name, yiyixuxu
authored
Fix bug in Textual Inversion Unloading (huggingface#9304)
* Update textual_inversion.py
* add unload test
* add comment
* fix style

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Your Name <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
1 parent 2541d14 commit 5d3e7bd

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

src/diffusers/loaders/textual_inversion.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -561,6 +561,8 @@ def unload_textual_inversion(
561561
tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
562562
key_id += 1
563563
tokenizer._update_trie()
564+
# set correct total vocab size after removing tokens
565+
tokenizer._update_total_vocab_size()
564566

565567
# Delete from text encoder
566568
text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim

tests/pipelines/test_pipelines.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -947,6 +947,27 @@ def test_text_inversion_multi_tokens(self):
947947
emb1[num_tokens + 1].sum().item() == emb2[num_tokens + 1].sum().item() == emb3[num_tokens + 1].sum().item()
948948
)
949949

950+
def test_textual_inversion_unload(self):
951+
pipe1 = StableDiffusionPipeline.from_pretrained(
952+
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
953+
)
954+
pipe1 = pipe1.to(torch_device)
955+
orig_tokenizer_size = len(pipe1.tokenizer)
956+
orig_emb_size = len(pipe1.text_encoder.get_input_embeddings().weight)
957+
958+
token = "<*>"
959+
ten = torch.ones((32,))
960+
pipe1.load_textual_inversion(ten, token=token)
961+
pipe1.unload_textual_inversion()
962+
pipe1.load_textual_inversion(ten, token=token)
963+
pipe1.unload_textual_inversion()
964+
965+
final_tokenizer_size = len(pipe1.tokenizer)
966+
final_emb_size = len(pipe1.text_encoder.get_input_embeddings().weight)
967+
# both should be restored to original size
968+
assert final_tokenizer_size == orig_tokenizer_size
969+
assert final_emb_size == orig_emb_size
970+
950971
def test_download_ignore_files(self):
951972
# Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
952973
with tempfile.TemporaryDirectory() as tmpdirname:

0 commit comments

Comments
 (0)