Commit 1666dd2 ("working version")
1 parent f4d6a82
File tree: 4 files changed, +108 -64 lines


model2vec/distill/distillation.py

Lines changed: 0 additions & 2 deletions

@@ -107,10 +107,8 @@ def distill_from_model(
     pad_token = tokenizer.special_tokens_map.get("pad_token")
     # Add the cleaned vocabulary to the tokenizer.
     backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)
-
     # Post process the embeddings by applying PCA and Zipf weighting.
     embeddings = _post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
-
     # Quantize the embeddings.
     embeddings = quantize_embeddings(embeddings, quantize_to)
 
model2vec/distill/inference.py

Lines changed: 41 additions & 2 deletions

@@ -9,6 +9,7 @@
 
 import numpy as np
 import torch
+from tokenizers.models import BPE, Unigram, WordPiece
 from torch.nn.utils.rnn import pad_sequence
 from tqdm import tqdm
 from transformers import PreTrainedModel, PreTrainedTokenizerFast
@@ -71,6 +72,23 @@ def create_embeddings(
         # If the token remove regex is None, just use all tokens.
         id_list = list(range(len(tokenizer.get_vocab())))
 
+    id_set = set(id_list)
+    new_id_list = []
+    for token, idx in tokenizer.get_vocab().items():
+        if idx not in id_set:
+            continue
+
+        if (
+            tokenizer.backend_tokenizer.pre_tokenizer is not None
+            and not token.startswith("##")
+            and not token in tokens_to_keep
+        ):
+            pre_token = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(token)
+            if len(pre_token) > 1:
+                continue
+        new_id_list.append(idx)
+    id_list = new_id_list
+
     added_tokens_ids = [id for token, id in tokenizer.added_tokens_encoder.items() if token not in tokens_to_keep]
     ids = torch.Tensor(sorted(set(id_list) - set(added_tokens_ids))).long()
 
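
A side note, not part of the commit: the filter above drops vocabulary entries that the pre-tokenizer would split into more than one piece. A minimal sketch of that check, assuming a BERT-style tokenizer ("bert-base-uncased" is only an example):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
pre = tok.backend_tokenizer.pre_tokenizer

# A plain word stays a single piece, so its id would be kept.
print(pre.pre_tokenize_str("hello"))  # [('hello', (0, 5))]

# A token containing punctuation splits into several pieces, so it would be dropped.
print(pre.pre_tokenize_str("can't"))  # [('can', ...), ("'", ...), ('t', ...)]
```
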
@@ -82,7 +100,28 @@ def create_embeddings(
     eos = torch.full([len(ids)], fill_value=eos_token_id)
 
     tokenized.extend(torch.stack([bos, ids, eos], dim=1))
-    subword_tokens = [Token(x, True) for x in tokenizer.convert_ids_to_tokens(ids.tolist())]
+
+    subword_tokens = []
+    for token in tokenizer.convert_ids_to_tokens(ids.tolist()):
+        is_subword = False
+        should_be_pretokenized = True
+        if token == unk_token or token == pad_token:
+            is_subword = True
+        elif isinstance(tokenizer.backend_tokenizer.model, WordPiece):
+            prefix_char = tokenizer.backend_tokenizer.model.continuing_subword_prefix
+            if token.startswith(prefix_char):
+                is_subword = True
+        elif isinstance(tokenizer.backend_tokenizer.model, Unigram):
+            if not token.startswith("▁"):
+                is_subword = True
+        elif isinstance(tokenizer.backend_tokenizer.model, BPE):
+            if not token.startswith("Ġ"):
+                is_subword = True
+            should_be_pretokenized = False
+        else:
+            should_be_pretokenized = False
+        subword_tokens.append(Token(token, is_subword, should_be_pretokenized))
+
     out_tokens.extend(subword_tokens)
 
     tokenized.extend([tokenizer.encode_plus(token, return_tensors="pt")["input_ids"][0] for token in tokens])
@@ -113,7 +152,7 @@ def create_embeddings(
 
     # Sort the output back to the original order
     intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
-    out_tokens.extend([Token(x, False) for x in tokens])
+    out_tokens.extend([Token(x, False, True) for x in tokens])
     out_weights = np.stack(intermediate_weights)
 
     return out_tokens, out_weights
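
For reference, the per-model branching above relies on each model's word-boundary convention: WordPiece marks continuing subwords (e.g. "##ing"), while Unigram and byte-level BPE mark word starts ("▁" and "Ġ" respectively), so tokens without that marker are the subwords. A small standalone sketch of those conventions (the helper and tokens are illustrative, not part of the commit):

```python
def looks_like_subword(token: str, model_type: str) -> bool:
    # WordPiece: continuing subwords carry the "##" prefix.
    if model_type == "wordpiece":
        return token.startswith("##")
    # Unigram/SentencePiece and byte-level BPE: word starts carry a marker,
    # so anything without the marker is a continuation piece.
    if model_type in ("unigram", "bpe"):
        marker = "▁" if model_type == "unigram" else "Ġ"
        return not token.startswith(marker)
    return False


print(looks_like_subword("##ing", "wordpiece"))  # True
print(looks_like_subword("▁hello", "unigram"))   # False
print(looks_like_subword("ing", "bpe"))          # True
```
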

model2vec/distill/tokenizer.py

Lines changed: 60 additions & 59 deletions

@@ -2,14 +2,16 @@
 
 import json
 import logging
+from string import punctuation
 from typing import Any
 
-from tokenizers import Tokenizer
+from tokenizers import Regex, Tokenizer
+from tokenizers.normalizers import Lowercase, Normalizer, Replace, Strip
+from tokenizers.normalizers import Sequence as NormalizerSequence
 from tokenizers.pre_tokenizers import (
     BertPreTokenizer,
     ByteLevel,
     CharDelimiterSplit,
-    Digits,
     Metaspace,
     PreTokenizer,
     Punctuation,
@@ -45,7 +47,7 @@
 }
 
 
-def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[Token]) -> list[str]:
+def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[Token], subword_prefix: str) -> list[str]:
     """
     Apply pre-tokenization to vocabulary tokens if a pre-tokenizer is present.
 
@@ -54,19 +56,28 @@ def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[Token]) -> list[str]:
 
     :param tokenizer: The tokenizer to use.
     :param tokens: The tokens to pre-tokenize.
+    :param subword_prefix: The prefix for subwords.
     :return: The pre-tokenized tokens.
     """
     pre_tokenized_tokens = []
 
     if tokenizer.pre_tokenizer is not None:
         for token in tokens:
-            if token.is_original:
+            if token.is_subword:
                 # Original tokens do not need to be pre-tokenized.
-                pre_tokenized_tokens.append(token.form)
-            else:
+                form = token.form
+                if subword_prefix is not None:
+                    form = token.form.removeprefix(subword_prefix)
+                pre_tokenized_tokens.append(form)
+            elif token.should_be_pretokenized:
                 # Join tokens just to be sure.
+                token.form = tokenizer.normalizer.normalize_str(token.form).rstrip()
                 pretokenized_tokens, _ = zip(*tokenizer.pre_tokenizer.pre_tokenize_str(token.form))
-                pre_tokenized_tokens.append(" ".join(pretokenized_tokens))
+                form = " ".join(pretokenized_tokens)
+                pre_tokenized_tokens.append(form)
+            else:
+                token.form = tokenizer.normalizer.normalize_str(token.form).rstrip()
+                pre_tokenized_tokens.append(token.form)
     else:
         pre_tokenized_tokens = [token.form for token in tokens]
 
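
As an aside (not part of the commit), the "elif token.should_be_pretokenized" branch boils down to: normalize the form, split it with the pre-tokenizer, and re-join the pieces with single spaces. A minimal sketch using BertPreTokenizer, which is just an example:

```python
from tokenizers.pre_tokenizers import BertPreTokenizer

# A multiword entry is split and re-joined, so it survives as one vocabulary item.
pieces, _ = zip(*BertPreTokenizer().pre_tokenize_str("new york"))
print(" ".join(pieces))  # "new york"
```
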
@@ -95,12 +106,38 @@ def _remap_added_tokens(
     return special_tokens
 
 
+def _prepare_normalizer(
+    normalizer: Normalizer,
+) -> Normalizer:
+    """
+    Prepare the normalizer for the tokenizer.
+
+    This function sets the normalizer for the tokenizer based on the provided normalizer type.
+    If no normalizer is provided, it uses the default one.
+
+    :param normalizer: The tokenizer to prepare.
+    :return: The prepared tokenizer.
+    """
+    new_normalizers = []
+    for char in punctuation:
+        new_normalizers.append(Replace(char, f" {char} "))
+    new_normalizers.append(Replace(Regex(r"\s+"), " "))
+    new_normalizers.append(Strip(right=True))
+    if normalizer is None:
+        return NormalizerSequence(new_normalizers)
+
+    return NormalizerSequence([normalizer] + new_normalizers)
+
+
 def _fix_single_pretokenizer(pretokenizer: PreTokenizer) -> PreTokenizer | None:
     """Fixes a single pretokenizer to allow multiword units."""
+    if isinstance(pretokenizer, Metaspace):
+        return Metaspace(split=False, replacement=pretokenizer.replacement, prepend_scheme=pretokenizer.prepend_scheme)
     if isinstance(pretokenizer, _FORBIDDEN_PRETOKENIZERS):
-        return Metaspace(split=False, replacement="Ġ")
+        return Metaspace(split=False, replacement="")
     elif isinstance(pretokenizer, ByteLevel):
         pretokenizer.use_regex = False
+        pretokenizer.add_prefix_space = True
 
     return pretokenizer
 
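
Not part of the commit, but a rough illustration of the normalizer that _prepare_normalizer builds: it surrounds every punctuation character with spaces, collapses runs of whitespace, and strips the result. Similarly, Metaspace(split=False) keeps a multiword string as a single piece instead of splitting it on the space marker.

```python
from string import punctuation

from tokenizers import Regex
from tokenizers.normalizers import Replace, Sequence, Strip

# Mirror the steps appended by _prepare_normalizer (without a pre-existing normalizer).
steps = [Replace(char, f" {char} ") for char in punctuation]
steps += [Replace(Regex(r"\s+"), " "), Strip(right=True)]
normalizer = Sequence(steps)

print(normalizer.normalize_str("hello, world!"))  # "hello , world !"
```
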
@@ -111,68 +148,29 @@ def _fix_pretokenizer_for_super(pre: PreTokenizer | None) -> Tokenizer:
         return pre
 
     if isinstance(pre, Sequence):
-        new_pretokenizers = []
-        for pretokenizer in pre:
-            new_pretokenizers.append(_fix_single_pretokenizer(pretokenizer))
-        return Sequence(new_pretokenizers)
+        return Metaspace(split=False)
 
     return _fix_single_pretokenizer(pre)
 
 
-def _make_new_merges_from_vocab(
-    merges: list[tuple[str, str]], tokens: list[str], special_tokens: set[str | None]
-) -> list[tuple[str, str]]:
-    """
-    Generate new merges from a vocabulary.
-
-    This function creates new merge pairs from a given vocabulary of tokens.
-    The merges are used to build or extend a tokenizer's merge table.
-
-    :param merges: The list of existing merges in the form (first, second) where first and second are tokens.
-    :param tokens: The list of tokens (vocabulary) from which to generate new merges.
-    :param special_tokens: Tokens that should not be merged.
-    :return: The list of new merges in the form (first, second) where first and second are tokens.
-    """
-    new_merges = merges.copy()
-    current_vocab = set(tokens) - special_tokens
-    already_merged = set("".join(merge) for merge in merges)
-
-    for token in tokens:
-        if token in special_tokens:
-            continue
-        if token in already_merged:
-            continue
-        if len(token) == 1:
-            continue
-        merges = []
-        for index in range(1, len(token)):
-            first, second = token[:index], token[index:]
-            if first in current_vocab and second in current_vocab:
-                merges.append((first, second))
-        if not merges:
-            logger.warning(f"Token {token} has no merges.")
-            continue
-        new_merges.extend(merges)
-
-    return new_merges
-
-
 def _process_wordpiece(
     tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
 ) -> dict[str, Any]:
     """Process the WordPiece tokenizer JSON."""
-    tokenizer_json["model"]["unk_token"] = unk_token
-    tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
+    tokenizer_json["model"]["type"] = "Unigram"
+    tokenizer_json["model"]["unk_id"] = pre_tokenized_tokens.index(unk_token) if unk_token else None
+    tokenizer_json["model"]["vocab"] = [(token, 0.0) for token in pre_tokenized_tokens]
 
     return tokenizer_json
 
 
-def _process_bpe(tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str]) -> dict[str, Any]:
+def _process_bpe(
+    tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
+) -> dict[str, Any]:
     """Process the BPE tokenizer JSON."""
-    tokenizer_json = _process_wordpiece(tokenizer_json, pre_tokenized_tokens, None)
-    merges = tokenizer_json["model"]["merges"]
-    merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, {"[UNK]", "[PAD]"})
-    tokenizer_json["model"]["merges"] = merges
+    tokenizer_json["model"]["type"] = "Unigram"
+    tokenizer_json["model"]["unk_id"] = pre_tokenized_tokens.index(unk_token) if unk_token else None
+    tokenizer_json["model"]["vocab"] = [(token, 0.0) for token in pre_tokenized_tokens]
 
     return tokenizer_json
 
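
With this change both WordPiece and BPE models are rewritten into a Unigram model whose vocabulary is simply the pre-tokenized token list with a flat score of 0.0, which is why the merge-table reconstruction above could be deleted. As a sketch, the resulting "model" section of tokenizer.json would look roughly like this (token strings are illustrative):

```python
model_section = {
    "type": "Unigram",
    "unk_id": 0,  # index of "[UNK]" in the new vocabulary, or None if there is no unk token
    "vocab": [
        ["[UNK]", 0.0],
        ["[PAD]", 0.0],
        ["hello", 0.0],
        ["new york", 0.0],  # multiword units become plain vocabulary entries
    ],
}
```
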
@@ -194,13 +192,16 @@ def replace_vocabulary(
     tokenizer: Tokenizer, new_vocabulary: list[Token], unk_token: str | None, pad_token: str | None
 ) -> Tokenizer:
     """Replace the vocabulary of a tokenizer with a new one."""
+    tokenizer.normalizer = _prepare_normalizer(tokenizer.normalizer)
     tokenizer.pre_tokenizer = _fix_pretokenizer_for_super(tokenizer.pre_tokenizer)
     tokenizer_json: dict[str, Any] = json.loads(tokenizer.to_str())
 
     # NOTE: all tokens have been normalized before.
     # Very careful, we need to pretokenize words before adding them to the vocabulary.
     # But only if they are not part of the original vocabulary.
-    pre_tokenized_tokens = _pre_tokenize_vocabulary(tokenizer, new_vocabulary)
+    subword_prefix = tokenizer_json["model"].get("continuing_subword_prefix", "")
+
+    pre_tokenized_tokens = _pre_tokenize_vocabulary(tokenizer, new_vocabulary, subword_prefix=subword_prefix)
 
     model_type = tokenizer_json["model"]["type"]
     added_tokens: list[dict[str, Any]] = tokenizer_json["added_tokens"]
@@ -215,7 +216,7 @@ def replace_vocabulary(
     if model_type == "WordPiece":
         tokenizer_json = _process_wordpiece(tokenizer_json, pre_tokenized_tokens, "[UNK]")
     elif model_type == "BPE":
-        tokenizer_json = _process_bpe(tokenizer_json, pre_tokenized_tokens)
+        tokenizer_json = _process_bpe(tokenizer_json, pre_tokenized_tokens, "[UNK]")
     elif model_type == "Unigram":
         tokenizer_json = _process_unigram(tokenizer_json, pre_tokenized_tokens, "[UNK]")
     else:

model2vec/distill/utils.py

Lines changed: 7 additions & 1 deletion

@@ -14,7 +14,13 @@ class Token:
     """A class to represent a token."""
 
     form: str
-    is_original: bool
+    # Whether the word is a continuing subword.
+    is_subword: bool
+    # Whether it should be pretokenized.
+    # This is independent of is_subword, because some
+    # tokenizer models like BPE and Unigram do not have a
+    # continuing subword prefix, but instead prefix nonsubwords.
+    should_be_pretokenized: bool
 
 
 def select_optimal_device(device: str | None) -> str:
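
A quick illustration of the extended dataclass, assuming Token can be constructed directly as it is in create_embeddings (the token strings are examples):

```python
from model2vec.distill.utils import Token

# A WordPiece continuation piece from the original vocabulary.
wordpiece_piece = Token(form="##ing", is_subword=True, should_be_pretokenized=True)

# A newly added (possibly multiword) token; mirrors Token(x, False, True) in inference.py.
new_entry = Token(form="new york", is_subword=False, should_be_pretokenized=True)
```
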
