@@ -9,10 +9,10 @@
 from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast
 
 from model2vec.distill.inference import PCADimType, create_embeddings, post_process_embeddings
-from model2vec.distill.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
 from model2vec.distill.utils import select_optimal_device
 from model2vec.model import StaticModel
 from model2vec.quantization import DType, quantize_embeddings
+from model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
 
 try:
     # For huggingface_hub>=0.25.0
@@ -91,8 +91,11 @@ def distill_from_model(
         raise ValueError("The vocabulary is empty after preprocessing. Please check your token_remove_pattern.")
 
     # Create the embeddings.
-    unk_token = tokenizer.special_tokens_map.get("unk_token")
-    pad_token = tokenizer.special_tokens_map.get("pad_token")
+    unk_token: str | None = tokenizer.special_tokens_map.get("unk_token")
+    pad_token: str | None = tokenizer.special_tokens_map.get("pad_token")
+
+    # Add the cleaned vocabulary to the tokenizer.
+    backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)
 
     # Convert tokens to IDs
     token_ids = turn_tokens_into_ids(all_tokens, tokenizer, unk_token)
@@ -101,8 +104,6 @@ def distill_from_model(
         tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
     )
 
-    # Add the cleaned vocabulary to the tokenizer.
-    backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)
     # Post process the embeddings by applying PCA and Zipf weighting.
     embeddings = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
     # Quantize the embeddings.
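
For context: the hunks above move the replace_vocabulary() call so the backend tokenizer is updated before tokens are converted to IDs and embedded, rather than after embedding. A minimal usage sketch of the public entry point that reaches this code path follows; the model name is illustrative, and it assumes distill() loads the model and tokenizer and delegates to distill_from_model():

from model2vec.distill import distill

# Distill a static embedding model from a Sentence Transformer checkpoint.
# pca_dims controls the dimensionality reduction applied in post_process_embeddings.
m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256)
m2v_model.save_pretrained("m2v_model")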