 from tokenizers import Encoding, Tokenizer
 from tqdm import tqdm

-from model2vec.quantization import DType, quantize_and_reduce_dim, quantize_vocabulary
+from model2vec.quantization import DType
 from model2vec.utils import ProgressParallel, load_local_model

 PathLike = Union[Path, str]
@@ -63,7 +63,7 @@ def __init__(
         self.weights = weights
         # Convert to an array for fast lookups
         # We can't use or short circuit here because np.ndarray as booleans are ambiguous.
-        self.token_mapping = None if token_mapping is None else np.asarray(token_mapping)
+        self.token_mapping: np.ndarray | None = None if token_mapping is None else np.asarray(token_mapping)

         self.tokenizer = tokenizer
         self.unk_token_id: int | None
@@ -194,39 +194,16 @@ def from_pretrained(
         :param vocabulary_quantization: The number of clusters to use for vocabulary quantization.
         :return: A StaticModel.
         """
-        from model2vec.hf_utils import load_pretrained
-
-        embeddings, tokenizer, config, metadata, weights = load_pretrained(
-            folder_or_repo_path=path,
+        return _loading_helper(
+            cls=cls,
+            path=path,
             token=token,
-            from_sentence_transformers=False,
-            subfolder=subfolder,
-        )
-
-        # Quantize the vocabulary at full precision and dimensionality
-        if vocabulary_quantization is not None:
-            embeddings, token_mapping, weights = quantize_vocabulary(
-                n_clusters=vocabulary_quantization, weights=weights, embeddings=embeddings
-            )
-        else:
-            token_mapping = config.pop("token_mapping", None)
-
-        # Reduce dimensionality and quantize if requested
-        embeddings = quantize_and_reduce_dim(
-            embeddings=embeddings,
+            vocabulary_quantization=vocabulary_quantization,
             quantize_to=quantize_to,
             dimensionality=dimensionality,
-        )
-
-        return cls(
-            vectors=embeddings,
-            tokenizer=tokenizer,
-            weights=weights,
-            token_mapping=token_mapping,
-            config=config,
+            from_sentence_transformers=False,
             normalize=normalize,
-            base_model_name=metadata.get("base_model"),
-            language=metadata.get("language"),
+            subfolder=subfolder,
         )

     @classmethod
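For readers following the refactor: the public behaviour of `from_pretrained` is unchanged, it now simply forwards its arguments to `_loading_helper`. A minimal usage sketch, assuming the keyword arguments shown in this diff and a placeholder repository id (`"your-org/your-model"` is not a real model), with `"int8"` assumed to be an accepted dtype string for `quantize_to`:

```python
from model2vec import StaticModel

# Placeholder repo id; any Model2Vec model directory or Hub repo works here.
model = StaticModel.from_pretrained(
    "your-org/your-model",
    quantize_to="int8",   # assumed dtype string understood by DType
    dimensionality=128,   # must be smaller than the model's native dimensionality
)
vectors = model.encode(["a quick example sentence"])
```

The same arguments apply to `from_sentence_transformers` in the next hunk, which differs only in passing `from_sentence_transformers=True` and `subfolder=None`.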
@@ -255,38 +232,16 @@ def from_sentence_transformers(
         :param vocabulary_quantization: The number of clusters to use for vocabulary quantization.
         :return: A StaticModel.
         """
-        from model2vec.hf_utils import load_pretrained
-
-        embeddings, tokenizer, config, metadata, weights = load_pretrained(
-            folder_or_repo_path=path,
+        return _loading_helper(
+            cls=cls,
+            path=path,
             token=token,
-            from_sentence_transformers=True,
-        )
-
-        # Quantize the vocabulary at full precision and dimensionality
-        if vocabulary_quantization is not None:
-            embeddings, token_mapping, weights = quantize_vocabulary(
-                n_clusters=vocabulary_quantization, weights=weights, embeddings=embeddings
-            )
-        else:
-            token_mapping = config.pop("token_mapping", None)
-
-        # Reduce dimensionality and quantize if requested
-        embeddings = quantize_and_reduce_dim(
-            embeddings=embeddings,
+            vocabulary_quantization=vocabulary_quantization,
             quantize_to=quantize_to,
             dimensionality=dimensionality,
-        )
-
-        return cls(
-            vectors=embeddings,
-            tokenizer=tokenizer,
-            weights=weights,
-            token_mapping=token_mapping,
-            config=config,
+            from_sentence_transformers=True,
             normalize=normalize,
-            base_model_name=metadata.get("base_model"),
-            language=metadata.get("language"),
+            subfolder=None,
         )

     @overload
@@ -381,7 +336,7 @@ def _encode_batch_as_sequence(self, sentences: Sequence[str], max_length: int
         out: list[np.ndarray] = []
         for id_list in ids:
             if id_list:
-                out.append(self.embedding[id_list])
+                out.append(self._encode_helper(id_list))
             else:
                 out.append(np.zeros((0, self.dim)))

@@ -450,23 +405,35 @@ def encode(
             return out_array[0]
         return out_array

+    def _encode_helper(self, id_list: list[int]) -> np.ndarray:
+        """
+        Helper function to encode a list of ids.
+
+        This function is used to deduplicate the logic in `encode` and `encode_as_sequence`.
+        It retrieves the embeddings for the given list of ids, applying weights if available.
+
+        :param id_list: A list of token ids.
+        :return: The embeddings for the given ids, as a sequence of vectors.
+        """
+        id_list_remapped: list[int] | np.ndarray
+        if self.token_mapping is None:
+            id_list_remapped = id_list
+        else:
+            id_list_remapped = self.token_mapping[id_list]
+        emb = self.embedding[id_list_remapped]
+        if self.weights is not None:
+            emb = emb * self.weights[id_list][:, None]
+
+        return emb
+
     def _encode_batch(self, sentences: Sequence[str], max_length: int | None) -> np.ndarray:
         """Encode a batch of sentences."""
         ids = self.tokenize(sentences=sentences, max_length=max_length)
         out: list[np.ndarray] = []
         for id_list in ids:
             if id_list:
-                id_list_remapped: list[int] | np.ndarray
-                if self.token_mapping is None:
-                    id_list_remapped = id_list
-                else:
-                    id_list_remapped = self.token_mapping[id_list]
-                emb = self.embedding[id_list_remapped]
-                if self.weights is not None:
-                    emb = emb * self.weights[id_list][:, None]
-                emb = emb.mean(axis=0)
-
-                out.append(emb)
+                emb = self._encode_helper(id_list)
+                out.append(emb.mean(axis=0))
             else:
                 out.append(np.zeros(self.dim))

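To make the deduplicated lookup easier to follow, here is a standalone NumPy sketch of what `_encode_helper` does: remap token ids through `token_mapping` when one exists, gather the corresponding embedding rows, and scale them by per-token weights. The arrays below are toy values, not taken from any real model.

```python
import numpy as np

embedding = np.array([[1.0, 0.0], [0.0, 1.0]])  # (n_clusters, dim) after vocabulary quantization
token_mapping = np.array([0, 0, 1, 1])          # token id -> shared embedding row
weights = np.array([0.5, 1.0, 1.0, 2.0])        # per-token weights

def encode_ids(id_list: list[int]) -> np.ndarray:
    """Mirror the helper's logic: remap ids, gather rows, apply weights."""
    remapped = token_mapping[id_list] if token_mapping is not None else id_list
    emb = embedding[remapped]
    if weights is not None:
        emb = emb * weights[id_list][:, None]
    return emb

per_token = encode_ids([0, 3])     # shape (2, 2), one row per token
sentence = per_token.mean(axis=0)  # _encode_batch averages; the sequence path returns per_token as-is
```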
@@ -529,3 +496,101 @@ def load_local(cls: type[StaticModel], path: PathLike) -> StaticModel:
         return StaticModel(
             vectors=embeddings, tokenizer=tokenizer, config=config, weights=weights, token_mapping=token_mapping
         )
+
+
+def quantize_model(
+    model: StaticModel,
+    vocabulary_quantization: int | None = None,
+    quantize_to: str | DType | None = None,
+    dimensionality: int | None = None,
+) -> StaticModel:
+    """
+    Quantize the model to a lower precision and possibly lower dimensionality.
+
+    :param model: The model to quantize.
+    :param vocabulary_quantization: The number of clusters to use for quantization.
+    :param quantize_to: The dtype to quantize the model to.
+    :param dimensionality: The desired dimensionality of the model.
+        This needs to be smaller than the current model dimensionality.
+    :return: A new StaticModel with the quantized embeddings.
+    :raises: ValueError if the model is already quantized.
+    """
+    from model2vec.quantization import quantize_and_reduce_dim
+
+    token_mapping: list[int] | None
+    weights: np.ndarray | None
+    if vocabulary_quantization is not None:
+        from model2vec.vocabulary_quantization import quantize_vocabulary
+
+        if len(model.tokens) != len(model.embedding):
+            raise ValueError("Model already has been vocabulary quantized, cannot quantize again.")
+
+        embeddings, token_mapping, weights = quantize_vocabulary(
+            n_clusters=vocabulary_quantization, weights=model.weights, embeddings=model.embedding
+        )
+    else:
+        embeddings = model.embedding
+        token_mapping = cast(list[int], model.token_mapping.tolist()) if model.token_mapping is not None else None
+        weights = model.weights
+    if quantize_to is not None or dimensionality is not None:
+        embeddings = quantize_and_reduce_dim(
+            embeddings=embeddings,
+            quantize_to=quantize_to,
+            dimensionality=dimensionality,
+        )
+
+    return StaticModel(
+        vectors=embeddings,
+        tokenizer=model.tokenizer,
+        config=model.config,
+        weights=weights,
+        token_mapping=token_mapping,
+        normalize=model.normalize,
+        base_model_name=model.base_model_name,
+        language=model.language,
+    )
+
+
+def _loading_helper(
+    cls: type[StaticModel],
+    path: PathLike,
+    token: str | None,
+    vocabulary_quantization: int | None = None,
+    quantize_to: str | DType | None = None,
+    dimensionality: int | None = None,
+    from_sentence_transformers: bool = False,
+    normalize: bool | None = None,
+    subfolder: str | None = None,
+) -> StaticModel:
+    """Helper function to load a model from a directory."""
+    from model2vec.hf_utils import load_pretrained
+
+    if from_sentence_transformers and subfolder is not None:
+        raise ValueError("Subfolder is not supported for sentence transformers models.")
+
+    embeddings, tokenizer, config, metadata, weights = load_pretrained(
+        folder_or_repo_path=path,
+        token=token,
+        from_sentence_transformers=from_sentence_transformers,
+        subfolder=subfolder,
+    )
+
+    token_mapping = config.pop("token_mapping", None)
+
+    model = cls(
+        vectors=embeddings,
+        tokenizer=tokenizer,
+        weights=weights,
+        token_mapping=token_mapping,
+        config=config,
+        normalize=normalize,
+        base_model_name=metadata.get("base_model"),
+        language=metadata.get("language"),
+    )
+
+    return quantize_model(
+        model=model,
+        vocabulary_quantization=vocabulary_quantization,
+        quantize_to=quantize_to,
+        dimensionality=dimensionality,
+    )
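The new module-level `quantize_model` also allows post-hoc quantization of an already loaded model. A minimal sketch, assuming this file is `model2vec/model.py` (so the function is importable from `model2vec.model`) and using placeholder ids and values:

```python
from model2vec import StaticModel
from model2vec.model import quantize_model  # import path assumed from this diff

model = StaticModel.from_pretrained("your-org/your-model")  # placeholder repo id

# Lower precision and dimensionality; returns a new StaticModel, the original is untouched.
small = quantize_model(model, quantize_to="int8", dimensionality=64)

# Cluster the vocabulary into 256 shared embedding rows (raises if already vocabulary quantized).
tiny = quantize_model(model, vocabulary_quantization=256)
```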