Commit 87f61fe

Merge branch 'main' into vocquant

2 parents 9fe7e33 + c4b8254

5 files changed: +30 −21 lines

README.md

Lines changed: 5 additions & 9 deletions
```diff
@@ -1,14 +1,10 @@
 
-<div align="center">
-<picture>
-<img width="35%" alt="Model2Vec logo" src="assets/images/logo_v2.png">
-</picture>
-</a>
-</div>
+<h2 align="center">
+<img width="35%" alt="Model2Vec logo" src="assets/images/model2vec_logo.png"><br/>
+Fast State-of-the-Art Static Embeddings
+</h2>
+
 
-<div align="center">
-<h2>Fast State-of-the-Art Static Embeddings</h2>
-</div>
 
 <div align="center">
 <h2>
```

assets/images/model2vec_logo.png

1.55 MB

model2vec/model.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -371,7 +371,7 @@ def encode_as_sequence(
             return out_array[0]
         return out_array
 
-    def _encode_batch_as_sequence(self, sentences: list[str], max_length: int | None) -> list[np.ndarray]:
+    def _encode_batch_as_sequence(self, sentences: Sequence[str], max_length: int | None) -> list[np.ndarray]:
         """Encode a batch of sentences as a sequence."""
         ids = self.tokenize(sentences=sentences, max_length=max_length)
         out: list[np.ndarray] = []
```
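The only change here widens the `sentences` annotation from `list[str]` to `Sequence[str]`, so callers can pass any ordered collection of strings (a tuple, for instance) without tripping static type checkers. A minimal, hypothetical sketch of the same typing idea, separate from the Model2Vec codebase:

```python
from typing import Sequence


def lengths(sentences: Sequence[str]) -> list[int]:
    # Accepts lists, tuples, or any other Sequence of strings.
    return [len(s) for s in sentences]


print(lengths(["hello", "world"]))  # a list works
print(lengths(("hello", "world")))  # so does a tuple
```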

model2vec/tokenizer/model.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -40,4 +40,4 @@ def _process_unigram(
 def _calculate_token_weight_for_unigram(token: str) -> float:
     """Calculate the token weight for Unigram."""
     # Always prefer longer tokens.
-    return len(token) + token.count("▁")
+    return len(token) + token.count("▁") + token.count("Ġ")
```
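The Unigram weight heuristic already preferred longer tokens and counted the SentencePiece word-boundary marker `▁`; it now also counts the byte-level marker `Ġ` used by GPT-2-style vocabularies. A standalone sketch of the same formula (hypothetical helper name, same arithmetic as the diff):

```python
def token_weight(token: str) -> float:
    # Longer tokens are preferred; a word-boundary marker adds extra weight.
    # "▁" is the SentencePiece marker, "Ġ" the byte-level (GPT-2-style) one.
    return len(token) + token.count("▁") + token.count("Ġ")


print(token_weight("hello"))   # 5
print(token_weight("▁hello"))  # 7: six characters plus one marker
print(token_weight("Ġhello"))  # 7: the byte-level marker now gets the same boost
```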

model2vec/tokenizer/tokenizer.py

Lines changed: 23 additions & 10 deletions
```diff
@@ -65,7 +65,9 @@ def replace_vocabulary(
 
     # Remove old added tokens from added tokens
     tokenizer_json["added_tokens"] = [x for x in added_tokens if x["content"] in {"[UNK]", "[PAD]"}]
-    tokenizer_json = process_tokenizer(tokenizer_json, pre_tokenized_tokens, "[UNK]")
+    tokenizer_json = process_tokenizer(
+        tokenizer_json, pre_tokenized_tokens, "[UNK]" if "[UNK]" in pre_tokenized_tokens else None
+    )
 
     # Remap special tokens
     tokenizer_json["added_tokens"] = _remap_added_tokens(
@@ -111,11 +113,11 @@ def clean_and_create_vocabulary(
     internal_vocab: dict[str, int] = tokenizer.get_vocab()
     internal_tokens: list[str] = [k for k, _ in sorted(internal_vocab.items(), key=lambda x: x[1])]
 
+    cleaned_vocabulary = _process_internal_tokens(tokenizer, backend_tokenizer, internal_tokens, token_remove_regex)
     # Copy the backend tokenizer to avoid modifying the original.
     backend_tokenizer = backend_tokenizer.from_str(backend_tokenizer.to_str())
     backend_tokenizer = replace_normalizer(backend_tokenizer)
 
-    cleaned_vocabulary = _process_internal_tokens(tokenizer, backend_tokenizer, internal_tokens, token_remove_regex)
     internal_tokens_set = {token.form for token in cleaned_vocabulary}
 
     normalizer: Normalizer | None = backend_tokenizer.normalizer
@@ -302,15 +304,9 @@ def turn_tokens_into_ids(
     :param tokenizer: The tokenizer to use for converting tokens to IDs
     :param unk_token: The string form of the unk token.
     :return: List of token IDs corresponding to the input tokens
-    :raises ValueError: If the tokenizer returns an unexpected number of tokens for a single token
     """
     unk_id = None if unk_token is None else tokenizer.convert_tokens_to_ids(unk_token)
-
-    encoding = tokenizer.encode("a", add_special_tokens=True)
-
-    if len(encoding) != 3:
-        raise ValueError(f"Tokenizer returned {len(encoding)} tokens for a single token. This is not supported.")
-    bos, _, eos = encoding
+    prefix, suffix = find_eos_bos(tokenizer)
 
     token_ids: list[list[int]] = []
     for token in tokens:
@@ -321,13 +317,30 @@
             # Explicitly check and warn if `unk_id` appears, but don't crash.
             if unk_id is not None and token_id == unk_id and token.form != unk_token:
                 logger.warning(f"Token {token.form} was set to unk. This is wrong.")
-            token_ids.append([bos, token_id, eos])
+            token_ids.append([*prefix, token_id, *suffix])
         else:
             token_ids.append(tokenizer.encode(token.form))
 
     return token_ids
 
 
+def find_eos_bos(tokenizer: PreTrainedTokenizerFast) -> tuple[list[int], list[int]]:
+    """Finds the eos and bos tokens for a tokenizer."""
+    # Little bit complicated, because not all tokenizers have eos and bos tokens.
+    encoding = tokenizer.encode("a", add_special_tokens=True)
+    if len(encoding) != 3:
+        a_encoded = tokenizer.encode("a", add_special_tokens=False)
+        if len(a_encoded) != 1:
+            raise ValueError(
+                f"Error while encoding, couldn't determine eos and bos tokens. The model tokenizes 'a' to '{a_encoded}'"
+            )
+        a_idx = encoding.index(a_encoded[0])
+        prefix, suffix = encoding[:a_idx], encoding[a_idx + 1 :]
+    else:
+        prefix, suffix = encoding[:1], encoding[2:]
+    return prefix, suffix
+
+
 def _normalize_vocabulary_token(token: str, pre_tokenizer: PreTokenizer) -> str:
     """Normalize a token that is not in the initial token vocabulary."""
     # Add prefix space for byte tokenizers.
```
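Taken together, these tokenizer changes relax two hard assumptions: `replace_vocabulary` now passes "[UNK]" to `process_tokenizer` only when it is actually present in the new vocabulary, and the new `find_eos_bos` helper no longer assumes that encoding a single token yields exactly `[bos, token, eos]`. Instead it encodes `"a"` with and without special tokens and splits around the position of the bare `"a"` id, so any number of leading or trailing special tokens works. A hedged illustration of what the helper computes, assuming the Hugging Face `transformers` package and the `bert-base-uncased` checkpoint (the ids in the comments are examples):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")

with_special = tok.encode("a", add_special_tokens=True)   # e.g. [101, 1037, 102]
bare = tok.encode("a", add_special_tokens=False)          # e.g. [1037]

# Everything before the bare "a" id is the prefix ([CLS] here),
# everything after it is the suffix ([SEP] here).
a_idx = with_special.index(bare[0])
prefix, suffix = with_special[:a_idx], with_special[a_idx + 1:]
print(prefix, suffix)  # e.g. [101] [102]
```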
