File tree Expand file tree Collapse file tree 2 files changed +23
-0
lines changed Expand file tree Collapse file tree 2 files changed +23
-0
lines changed Original file line number Diff line number Diff line change @@ -561,6 +561,8 @@ def unload_textual_inversion(
561561 tokenizer ._added_tokens_encoder [token .content ] = last_special_token_id + key_id
562562 key_id += 1
563563 tokenizer ._update_trie ()
564+ # set correct total vocab size after removing tokens
565+ tokenizer ._update_total_vocab_size ()
564566
565567 # Delete from text encoder
566568 text_embedding_dim = text_encoder .get_input_embeddings ().embedding_dim
Original file line number Diff line number Diff line change @@ -947,6 +947,27 @@ def test_text_inversion_multi_tokens(self):
947947 emb1 [num_tokens + 1 ].sum ().item () == emb2 [num_tokens + 1 ].sum ().item () == emb3 [num_tokens + 1 ].sum ().item ()
948948 )
949949
950+ def test_textual_inversion_unload (self ):
951+ pipe1 = StableDiffusionPipeline .from_pretrained (
952+ "hf-internal-testing/tiny-stable-diffusion-torch" , safety_checker = None
953+ )
954+ pipe1 = pipe1 .to (torch_device )
955+ orig_tokenizer_size = len (pipe1 .tokenizer )
956+ orig_emb_size = len (pipe1 .text_encoder .get_input_embeddings ().weight )
957+
958+ token = "<*>"
959+ ten = torch .ones ((32 ,))
960+ pipe1 .load_textual_inversion (ten , token = token )
961+ pipe1 .unload_textual_inversion ()
962+ pipe1 .load_textual_inversion (ten , token = token )
963+ pipe1 .unload_textual_inversion ()
964+
965+ final_tokenizer_size = len (pipe1 .tokenizer )
966+ final_emb_size = len (pipe1 .text_encoder .get_input_embeddings ().weight )
967+ # both should be restored to original size
968+ assert final_tokenizer_size == orig_tokenizer_size
969+ assert final_emb_size == orig_emb_size
970+
950971 def test_download_ignore_files (self ):
951972 # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
952973 with tempfile .TemporaryDirectory () as tmpdirname :
You can’t perform that action at this time.
0 commit comments