Skip to content

Commit c4b8254

Browse files
authored
fix: missing unk, fix bug (#251)
* wip * fix: tokenizer bug * add correct scaling for byte
1 parent 22011b7 commit c4b8254

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

model2vec/tokenizer/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,4 @@ def _process_unigram(
4040
def _calculate_token_weight_for_unigram(token: str) -> float:
4141
"""Calculate the token weight for Unigram."""
4242
# Always prefer longer tokens.
43-
return len(token) + token.count("▁")
43+
return len(token) + token.count("▁") + token.count("Ġ")

model2vec/tokenizer/tokenizer.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def replace_vocabulary(
6565

6666
# Remove old added tokens from added tokens
6767
tokenizer_json["added_tokens"] = [x for x in added_tokens if x["content"] in {"[UNK]", "[PAD]"}]
68-
tokenizer_json = process_tokenizer(tokenizer_json, pre_tokenized_tokens, "[UNK]")
68+
tokenizer_json = process_tokenizer(
69+
tokenizer_json, pre_tokenized_tokens, "[UNK]" if "[UNK]" in pre_tokenized_tokens else None
70+
)
6971

7072
# Remap special tokens
7173
tokenizer_json["added_tokens"] = _remap_added_tokens(
@@ -111,11 +113,11 @@ def clean_and_create_vocabulary(
111113
internal_vocab: dict[str, int] = tokenizer.get_vocab()
112114
internal_tokens: list[str] = [k for k, _ in sorted(internal_vocab.items(), key=lambda x: x[1])]
113115

116+
cleaned_vocabulary = _process_internal_tokens(tokenizer, backend_tokenizer, internal_tokens, token_remove_regex)
114117
# Copy the backend tokenizer to avoid modifying the original.
115118
backend_tokenizer = backend_tokenizer.from_str(backend_tokenizer.to_str())
116119
backend_tokenizer = replace_normalizer(backend_tokenizer)
117120

118-
cleaned_vocabulary = _process_internal_tokens(tokenizer, backend_tokenizer, internal_tokens, token_remove_regex)
119121
internal_tokens_set = {token.form for token in cleaned_vocabulary}
120122

121123
normalizer: Normalizer | None = backend_tokenizer.normalizer
@@ -302,15 +304,9 @@ def turn_tokens_into_ids(
302304
:param tokenizer: The tokenizer to use for converting tokens to IDs
303305
:param unk_token: The string form of the unk token.
304306
:return: List of token IDs corresponding to the input tokens
305-
:raises ValueError: If the tokenizer returns an unexpected number of tokens for a single token
306307
"""
307308
unk_id = None if unk_token is None else tokenizer.convert_tokens_to_ids(unk_token)
308-
309-
encoding = tokenizer.encode("a", add_special_tokens=True)
310-
311-
if len(encoding) != 3:
312-
raise ValueError(f"Tokenizer returned {len(encoding)} tokens for a single token. This is not supported.")
313-
bos, _, eos = encoding
309+
prefix, suffix = find_eos_bos(tokenizer)
314310

315311
token_ids: list[list[int]] = []
316312
for token in tokens:
@@ -321,13 +317,30 @@ def turn_tokens_into_ids(
321317
# Explicitly check and warn if `unk_id` appears, but don't crash.
322318
if unk_id is not None and token_id == unk_id and token.form != unk_token:
323319
logger.warning(f"Token {token.form} was set to unk. This is wrong.")
324-
token_ids.append([bos, token_id, eos])
320+
token_ids.append([*prefix, token_id, *suffix])
325321
else:
326322
token_ids.append(tokenizer.encode(token.form))
327323

328324
return token_ids
329325

330326

327+
def find_eos_bos(tokenizer: PreTrainedTokenizerFast) -> tuple[list[int], list[int]]:
328+
"""Finds the eos and bos tokens for a tokenizer."""
329+
# Little bit complicated, because not all tokenizers have eos and bos tokens.
330+
encoding = tokenizer.encode("a", add_special_tokens=True)
331+
if len(encoding) != 3:
332+
a_encoded = tokenizer.encode("a", add_special_tokens=False)
333+
if len(a_encoded) != 1:
334+
raise ValueError(
335+
f"Error while encoding, couldn't determine eos and bos tokens. The model tokenizes 'a' to '{a_encoded}'"
336+
)
337+
a_idx = encoding.index(a_encoded[0])
338+
prefix, suffix = encoding[:a_idx], encoding[a_idx + 1 :]
339+
else:
340+
prefix, suffix = encoding[:1], encoding[2:]
341+
return prefix, suffix
342+
343+
331344
def _normalize_vocabulary_token(token: str, pre_tokenizer: PreTokenizer) -> str:
332345
"""Normalize a token that is not in the initial token vocabulary."""
333346
# Add prefix space for byte tokenizers.

0 commit comments

Comments
 (0)