 from logging import getLogger
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Any, Iterator, Sequence, Union, overload
+from typing import Any, Iterator, Sequence, Union, cast, overload
 
 import numpy as np
 from joblib import delayed
 from tokenizers import Encoding, Tokenizer
 from tqdm import tqdm
 
-from model2vec.quantization import DType, quantize_and_reduce_dim, vocabulary_quantization
+from model2vec.quantization import DType, quantize_and_reduce_dim, quantize_vocabulary
 from model2vec.utils import ProgressParallel, load_local_model
 
 PathLike = Union[Path, str]
@@ -25,12 +25,12 @@ def __init__(
         self,
         vectors: np.ndarray,
         tokenizer: Tokenizer,
-        weights: np.ndarray | None = None,
-        token_mapping: dict[int, int] | None = None,
         config: dict[str, Any] | None = None,
         normalize: bool | None = None,
         base_model_name: str | None = None,
         language: list[str] | None = None,
+        weights: np.ndarray | None = None,
+        token_mapping: list[int] | None = None,
     ) -> None:
@@ -41,6 +41,12 @@ def __init__(
         :param normalize: Whether to normalize the embeddings.
         :param base_model_name: The used base model name. Used for creating a model card.
         :param language: The language of the model. Used for creating a model card.
+        :param weights: The weights to use for the embeddings. If None, no weights are used.
+            We always assume the norm of the embeddings is an implicit weight anyway.
+            This is only used for models that have undergone vocabulary quantization.
+        :param token_mapping: A mapping from token ids to indices in the vectors.
+            If None, we don't remap the tokens during inference.
+            This is only used for models that have undergone vocabulary quantization.
         :raises: ValueError if the number of tokens does not match the number of vectors.
         """
         super().__init__()
@@ -55,7 +61,9 @@ def __init__(
 
         self.embedding = vectors
         self.weights = weights
-        self.token_mapping = token_mapping
+        # Convert to an array for fast lookups.
+        # We can't fall back with `or` here because the truth value of an np.ndarray is ambiguous.
+        self.token_mapping = None if token_mapping is None else np.asarray(token_mapping)
 
         self.tokenizer = tokenizer
         self.unk_token_id: int | None
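A quick note on the new comment: NumPy arrays refuse to coerce to a single boolean, so a `value or default` fallback would raise once the mapping is an array. A minimal illustration with made-up values:

    import numpy as np

    mapping = np.asarray([2, 0, 1])
    # bool(mapping) raises "The truth value of an array with more than one element is ambiguous",
    # which is why the code above uses an explicit `None if token_mapping is None else np.asarray(...)`.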
@@ -114,7 +122,7 @@ def save_pretrained(self, path: PathLike, model_name: str | None = None, subfold
         from model2vec.hf_utils import save_pretrained
 
         if self.token_mapping is not None:
-            self.config["token_mapping"] = list(self.token_mapping.items())
+            self.config["token_mapping"] = self.token_mapping.tolist()
 
         save_pretrained(
             folder_path=Path(path),
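With the mapping stored as an array instead of a dict, the config entry changes shape from id/index pairs to a flat list of ints, which also matches the new `list[int]` annotation on the constructor. A sketch with made-up values:

    import numpy as np

    old_mapping = {0: 2, 1: 0, 2: 1}
    new_mapping = np.asarray([2, 0, 1])

    list(old_mapping.items())  # [(0, 2), (1, 0), (2, 1)]  -- old config format
    new_mapping.tolist()       # [2, 0, 1]                 -- new config format, plain ints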
@@ -167,7 +175,7 @@ def from_pretrained(
         subfolder: str | None = None,
         quantize_to: str | DType | None = None,
         dimensionality: int | None = None,
-        quantize_vocabulary: int | None = None,
+        vocabulary_quantization: int | None = None,
     ) -> StaticModel:
         """
         Load a StaticModel from a local path or huggingface hub path.
@@ -183,6 +191,7 @@ def from_pretrained(
         :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
             This is useful if you want to load a model with a lower dimensionality.
             Note that this only applies if you have trained your model using mrl or PCA.
+        :param vocabulary_quantization: The number of clusters to use for vocabulary quantization.
         :return: A StaticModel.
         """
         from model2vec.hf_utils import load_pretrained
@@ -194,31 +203,27 @@ def from_pretrained(
             subfolder=subfolder,
         )
 
+        # Quantize the vocabulary at full precision and dimensionality
+        if vocabulary_quantization is not None:
+            embeddings, token_mapping, weights = quantize_vocabulary(
+                n_clusters=vocabulary_quantization, weights=weights, embeddings=embeddings
+            )
+        else:
+            token_mapping = config.pop("token_mapping", None)
+
+        # Reduce dimensionality and quantize if requested
         embeddings = quantize_and_reduce_dim(
             embeddings=embeddings,
             quantize_to=quantize_to,
             dimensionality=dimensionality,
         )
 
-        if quantize_vocabulary is not None:
-            embeddings, token_mapping, weights = vocabulary_quantization(
-                n_clusters=quantize_vocabulary, weights=weights, embeddings=embeddings
-            )
-        else:
-            token_mapping = config.pop("token_mapping", None)
-            if isinstance(token_mapping, list):
-                # If the token mapping is a list, convert it to a dict
-                token_mapping = {int(k): int(v) for k, v in token_mapping}
-            elif token_mapping is None:
-                # If no token mapping is provided, use the default mapping
-                token_mapping = {i: i for i in range(len(embeddings))}
-
         return cls(
-            embeddings,
-            tokenizer,
-            weights,
-            token_mapping,
-            config,
+            vectors=embeddings,
+            tokenizer=tokenizer,
+            weights=weights,
+            token_mapping=token_mapping,
+            config=config,
             normalize=normalize,
             base_model_name=metadata.get("base_model"),
             language=metadata.get("language"),
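A usage sketch of the new loading path (not from this commit; the checkpoint name, cluster count, and dtype are illustrative). Vocabulary quantization runs first at full precision and dimensionality, and any dtype or dimensionality reduction is applied afterwards to the already-clustered table; from_sentence_transformers below gains the same parameter:

    from model2vec import StaticModel

    model = StaticModel.from_pretrained(
        "minishlab/potion-base-8M",    # example model2vec checkpoint
        vocabulary_quantization=256,   # cluster the vocabulary down to 256 centroid vectors
        quantize_to="int8",            # optional further quantization of the centroid table
    )
    embeddings = model.encode(["an example sentence"])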
@@ -232,6 +237,7 @@ def from_sentence_transformers(
         normalize: bool | None = None,
         quantize_to: str | DType | None = None,
         dimensionality: int | None = None,
+        vocabulary_quantization: int | None = None,
     ) -> StaticModel:
         """
         Load a StaticModel trained with sentence transformers from a local path or huggingface hub path.
@@ -246,6 +252,7 @@ def from_sentence_transformers(
         :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
             This is useful if you want to load a model with a lower dimensionality.
             Note that this only applies if you have trained your model using mrl or PCA.
+        :param vocabulary_quantization: The number of clusters to use for vocabulary quantization.
         :return: A StaticModel.
         """
         from model2vec.hf_utils import load_pretrained
@@ -254,26 +261,29 @@ def from_sentence_transformers(
             folder_or_repo_path=path,
             token=token,
             from_sentence_transformers=True,
-            subfolder=None,
         )
 
+        # Quantize the vocabulary at full precision and dimensionality
+        if vocabulary_quantization is not None:
+            embeddings, token_mapping, weights = quantize_vocabulary(
+                n_clusters=vocabulary_quantization, weights=weights, embeddings=embeddings
+            )
+        else:
+            token_mapping = config.pop("token_mapping", None)
+
+        # Reduce dimensionality and quantize if requested
         embeddings = quantize_and_reduce_dim(
             embeddings=embeddings,
             quantize_to=quantize_to,
             dimensionality=dimensionality,
         )
 
-        token_mapping = config.pop("token_mapping", None)
-        if token_mapping is None:
-            # If no token mapping is provided, use the default mapping
-            token_mapping = {i: i for i in range(len(embeddings))}
-
         return cls(
-            embeddings,
-            tokenizer,
-            weights,
-            token_mapping,
-            config,
+            vectors=embeddings,
+            tokenizer=tokenizer,
+            weights=weights,
+            token_mapping=token_mapping,
+            config=config,
             normalize=normalize,
             base_model_name=metadata.get("base_model"),
             language=metadata.get("language"),
@@ -446,10 +456,11 @@ def _encode_batch(self, sentences: Sequence[str], max_length: int | None) -> np.
         out: list[np.ndarray] = []
         for id_list in ids:
             if id_list:
+                id_list_remapped: list[int] | np.ndarray
                 if self.token_mapping is None:
                     id_list_remapped = id_list
                 else:
-                    id_list_remapped = [self.token_mapping.get(token_id, token_id) for token_id in id_list]
+                    id_list_remapped = self.token_mapping[id_list]
                 emb = self.embedding[id_list_remapped]
                 if self.weights is not None:
                     emb = emb * self.weights[id_list][:, None]
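The removed per-token dict lookup and the new fancy-indexing lookup give the same remapping whenever every token id is covered by the mapping; a toy check with made-up values:

    import numpy as np

    token_mapping = np.asarray([2, 0, 1, 0])                 # token id -> row in the reduced table
    embedding = np.zeros((3, 4), dtype=np.float32)           # 3 centroid rows of dimension 4
    weights = np.asarray([1.0, 0.5, 2.0, 1.0], dtype=np.float32)

    id_list = [3, 1, 2]                                      # token ids from the tokenizer
    as_dict = dict(enumerate(token_mapping.tolist()))
    assert token_mapping[id_list].tolist() == [as_dict.get(t, t) for t in id_list]

    emb = embedding[token_mapping[id_list]] * weights[id_list][:, None]  # same weighting step as above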
@@ -512,6 +523,9 @@ def load_local(cls: type[StaticModel], path: PathLike) -> StaticModel:
         if not path.is_dir():
             raise ValueError(f"Path {path} is not a directory.")
 
-        embeddings, tokenizer, config = load_local_model(path)
+        embeddings, tokenizer, config, weights = load_local_model(path)
+        token_mapping = cast(list[int], config.pop("token_mapping", None))
 
-        return StaticModel(embeddings, tokenizer, config=config)
+        return StaticModel(
+            vectors=embeddings, tokenizer=tokenizer, config=config, weights=weights, token_mapping=token_mapping
+        )
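A round-trip sketch of what the save/load changes imply (not from this commit; the paths and checkpoint name are illustrative): save_pretrained writes the mapping into the config as a plain list, and load_local pops it back out so the constructor can re-wrap it as an array.

    from model2vec import StaticModel

    model = StaticModel.from_pretrained("minishlab/potion-base-8M", vocabulary_quantization=256)
    model.save_pretrained("potion-vq-256")              # config now holds "token_mapping": [...]
    reloaded = StaticModel.load_local("potion-vq-256")  # mapping and weights come back
    assert reloaded.token_mapping is not None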