Commit e2789ba

turn tokenizer into package

1 parent a972c10 commit e2789ba

9 files changed, +169 -137 lines changed

model2vec/distill/distillation.py

Lines changed: 6 additions & 5 deletions

@@ -9,10 +9,10 @@
 from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast

 from model2vec.distill.inference import PCADimType, create_embeddings, post_process_embeddings
-from model2vec.distill.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
 from model2vec.distill.utils import select_optimal_device
 from model2vec.model import StaticModel
 from model2vec.quantization import DType, quantize_embeddings
+from model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids

 try:
     # For huggingface_hub>=0.25.0
@@ -91,8 +91,11 @@ def distill_from_model(
         raise ValueError("The vocabulary is empty after preprocessing. Please check your token_remove_pattern.")

     # Create the embeddings.
-    unk_token = tokenizer.special_tokens_map.get("unk_token")
-    pad_token = tokenizer.special_tokens_map.get("pad_token")
+    unk_token: str | None = tokenizer.special_tokens_map.get("unk_token")
+    pad_token: str | None = tokenizer.special_tokens_map.get("pad_token")
+
+    # Add the cleaned vocabulary to the tokenizer.
+    backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)

     # Convert tokens to IDs
     token_ids = turn_tokens_into_ids(all_tokens, tokenizer, unk_token)
@@ -101,8 +104,6 @@ def distill_from_model(
         tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
     )

-    # Add the cleaned vocabulary to the tokenizer.
-    backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)
     # Post process the embeddings by applying PCA and Zipf weighting.
     embeddings = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
     # Quantize the embeddings.

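The public distillation API is not affected by this refactor; only the internal import path changes, and the cleaned vocabulary is now swapped into the backend tokenizer before the tokens are converted to IDs and embedded. For context, a typical call still looks like the sketch below (following the model2vec README; the model name is only an example):

from model2vec.distill import distill

# Distill a Sentence Transformer into a static model and save it.
m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256)
m2v_model.save_pretrained("m2v_model")
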
model2vec/distill/inference.py

Lines changed: 2 additions & 0 deletions

@@ -79,6 +79,8 @@ def create_embeddings(
     intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
     out_weights = np.stack(intermediate_weights)

+    out_weights = np.nan_to_num(out_weights)
+
     return out_weights


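The added np.nan_to_num call presumably guards against NaN entries in the stacked weight matrix before it is returned. A standalone illustration of what that call does, independent of model2vec:

import numpy as np

weights = np.array([[0.5, np.nan], [1.0, 2.0]])
# By default, NaN becomes 0.0 and +/-inf becomes a large finite value.
print(np.nan_to_num(weights))  # [[0.5 0. ] [1.  2. ]]
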
model2vec/distill/utils.py

Lines changed: 0 additions & 15 deletions

@@ -1,27 +1,12 @@
 from __future__ import annotations

-import re
-from dataclasses import dataclass
 from logging import getLogger

 import torch

 logger = getLogger(__name__)


-@dataclass
-class Token:
-    """A class to represent a token."""
-
-    form: str
-    # The normalized and pretokenized form of the token
-    normalized_form: str
-    # Whether the word is a continuing subword.
-    is_subword: bool
-    # Whether the token is internal to the model.
-    is_internal: bool
-
-
 def select_optimal_device(device: str | None) -> str:
     """
     Guess what your optimal device should be based on backend availability.

model2vec/tokenizer/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+from model2vec.utils import importable
+
+importable("transformers", "tokenizer")
+
+from model2vec.tokenizer.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
+
+__all__ = ["clean_and_create_vocabulary", "turn_tokens_into_ids", "replace_vocabulary"]

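For downstream code, the visible effect of the new package is the import path: the helpers that used to live in model2vec.distill.tokenizer are now imported from model2vec.tokenizer, behind the importable check on transformers. A minimal sketch:

# New import path introduced by this commit; requires transformers to be installed.
from model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
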
model2vec/tokenizer/datamodels.py

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class Token:
+    """A class to represent a token."""
+
+    form: str
+    # The normalized and pretokenized form of the token
+    normalized_form: str
+    # Whether the word is a continuing subword.
+    is_subword: bool
+    # Whether the token is internal to the model.
+    is_internal: bool

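The Token dataclass is moved verbatim out of model2vec/distill/utils.py (see the deletion above) into the new package. A small construction example; the field values below are illustrative only:

from model2vec.tokenizer.datamodels import Token

# A hypothetical word-initial token that is part of the model's own vocabulary.
token = Token(form="hello", normalized_form="▁hello", is_subword=False, is_internal=True)
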
model2vec/tokenizer/model.py

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+from typing import Any
+
+import numpy as np
+
+
+def process_tokenizer(
+    tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
+) -> dict[str, Any]:
+    """Process the WordPiece tokenizer JSON."""
+    tokenizer_json["model"]["type"] = "Unigram"
+    tokenizer_json["model"]["unk_id"] = pre_tokenized_tokens.index(unk_token) if unk_token else None
+
+    token_weights = np.asarray([_calculate_token_weight_for_unigram(token) for token in pre_tokenized_tokens])
+    proba = (token_weights / np.sum(token_weights)).tolist()
+    tokenizer_json["model"]["vocab"] = [(token, np.log(p)) for token, p in zip(pre_tokenized_tokens, proba)]
+
+    return tokenizer_json
+
+
+def process_unigram(tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str) -> dict[str, Any]:
+    """Process the Unigram tokenizer JSON."""
+    current_probas = dict(tokenizer_json["model"]["vocab"])
+    avg_proba = sum(current_probas.values()) / len(current_probas)
+    new_probas = [[word, current_probas.get(word, avg_proba)] for word in pre_tokenized_tokens]
+    tokenizer_json["model"]["vocab"] = new_probas
+
+    tokens, _ = zip(*tokenizer_json["model"]["vocab"])
+    tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token)
+
+    return tokenizer_json
+
+
+def _calculate_token_weight_for_unigram(token: str) -> float:
+    """Calculate the token weight for Unigram."""
+    # Always prefer longer tokens.
+    return len(token) + int(token.startswith("▁"))

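_calculate_token_weight_for_unigram favors longer tokens and adds one point for tokens that start with the metaspace "▁"; process_tokenizer then normalizes these weights into probabilities and stores their logs as the Unigram vocab scores. A standalone worked example of that arithmetic (the token list is made up):

import numpy as np

tokens = ["▁the", "▁cat", "s"]
# len(token) + 1 if it starts with "▁": weights are 5, 5 and 1.
weights = np.asarray([len(t) + int(t.startswith("▁")) for t in tokens])
proba = weights / weights.sum()  # [5/11, 5/11, 1/11]
vocab = [(t, float(np.log(p))) for t, p in zip(tokens, proba)]
print(vocab)  # longer, word-initial tokens get less negative log-probabilities
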
model2vec/tokenizer/normalizer.py

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+from string import punctuation
+
+from tokenizers import Regex
+from tokenizers.normalizers import Normalizer, Replace, Sequence, Strip
+
+
+def prepare_normalizer(
+    normalizer: Normalizer,
+) -> Normalizer:
+    """
+    Prepare the normalizer for the tokenizer.
+
+    This function sets the normalizer for the tokenizer based on the provided normalizer type.
+    If no normalizer is provided, it uses the default one.
+
+    :param normalizer: The tokenizer to prepare.
+    :return: The prepared tokenizer.
+    """
+    new_normalizers = []
+    for char in punctuation:
+        new_normalizers.append(Replace(char, f" {char} "))
+
+    new_normalizers.append(Replace(Regex(r"\s+"), " "))
+    new_normalizers.append(Strip(right=True))
+    if normalizer is None:
+        return Sequence(new_normalizers)
+
+    return Sequence([normalizer] + new_normalizers)
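
prepare_normalizer appends rules that pad every punctuation character with spaces, collapse runs of whitespace, and strip trailing whitespace, chaining them after an existing normalizer if one is given. A small sketch of the effect, assuming the tokenizers library is installed (passing None is accepted by the body even though the annotation says Normalizer):

from model2vec.tokenizer.normalizer import prepare_normalizer

normalizer = prepare_normalizer(None)
# Punctuation gets surrounded by spaces, whitespace is collapsed, and the right side is stripped.
print(normalizer.normalize_str("hello, world!"))  # "hello , world !"
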
Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+from typing import Any
+
+_FORBIDDEN_PRETOKENIZERS = (
+    "WhiteSpace",
+    "WhitespaceSplit",
+    "BertPreTokenizer",
+    "CharDelimiterSplit",
+    "Punctuation",
+    "Split",
+    "UnicodeScripts",
+)
+_BASIC_METASPACE = {"type": "Metaspace", "replacement": "▁", "prepend_scheme": "always", "split": False}
+
+
+def _fix_single_pretokenizer(pre_tokenizer: dict[str, Any]) -> dict[str, Any] | None:
+    """Fixes a single pretokenizer to allow multiword units."""
+    if pre_tokenizer["type"] in _FORBIDDEN_PRETOKENIZERS:
+        return None
+    if pre_tokenizer["type"] == "ByteLevel":
+        pre_tokenizer["add_prefix_space"] = True
+        pre_tokenizer["use_regex"] = False
+    if pre_tokenizer["type"] == "Metaspace":
+        pre_tokenizer["split"] = False
+        pre_tokenizer["prepend_scheme"] = "always"
+
+    return pre_tokenizer
+
+
+def fix_pretokenizer(pretokenizer: dict[str, Any] | None) -> dict[str, Any]:
+    """Fixes a single pretokenizer to allow multiword units."""
+    if pretokenizer is None:
+        return _BASIC_METASPACE
+
+    if pretokenizer["type"] == "Sequence":
+        new_pretokenizers = []
+        for single_pretokenizer in pretokenizer["pretokenizers"]:
+            new_pretokenizer = _fix_single_pretokenizer(single_pretokenizer)
+            if new_pretokenizer is not None:
+                new_pretokenizers.append(new_pretokenizer)
+        pretokenizer["pretokenizers"] = new_pretokenizers
+
+        if not pretokenizer:
+            return _BASIC_METASPACE
+
+        return pretokenizer
+
+    single_pretokenizer = _fix_single_pretokenizer(pretokenizer)
+    if single_pretokenizer is None:
+        return _BASIC_METASPACE
+
+    return single_pretokenizer

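The final file in this diff (its path is cut off in this extract) rewrites a tokenizer.json pre_tokenizer so that multiword units survive: splitting pretokenizers are dropped, ByteLevel and Metaspace are switched to non-splitting settings, and a bare non-splitting Metaspace is used as a fallback. A small sketch on made-up pre_tokenizer dicts; the module path in the import is an assumption:

# Assumed module path, since the file name is not visible in this extract.
from model2vec.tokenizer.pretokenizer import fix_pretokenizer

byte_level = {"type": "ByteLevel", "add_prefix_space": False, "use_regex": True}
print(fix_pretokenizer(byte_level))
# {'type': 'ByteLevel', 'add_prefix_space': True, 'use_regex': False}

# A forbidden (splitting) pretokenizer falls back to the basic non-splitting Metaspace.
print(fix_pretokenizer({"type": "BertPreTokenizer"}))
# {'type': 'Metaspace', 'replacement': '▁', 'prepend_scheme': 'always', 'split': False}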