
Commit 02f5591

tests: add a lot of tests
1 parent f6a27a4 commit 02f5591

File tree

4 files changed (+142 −12 lines)


model2vec/tokenizer/normalizer.py

Lines changed: 2 additions & 2 deletions
@@ -26,9 +26,9 @@ def replace_normalizer(
     new_normalizers.append(Replace(Regex(r"\s+"), " "))
     new_normalizers.append(Strip(right=True))
     if normalizer is None:
-        normalizer = Sequence(new_normalizers)
+        normalizer = Sequence(new_normalizers)  # type: ignore
     else:
         normalizer = Sequence([normalizer] + new_normalizers)  # type: ignore
-    tokenizer.normalizer = normalizer
+    tokenizer.normalizer = normalizer  # type: ignore
 
     return tokenizer
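
For context on the lines touched above: replace_normalizer appends a whitespace-collapsing Replace and a right-hand Strip, then wraps everything in a Sequence. A minimal stand-alone sketch of that chain using only the tokenizers library (the surrounding normalizers in model2vec may differ; this only illustrates the two appended steps):

from tokenizers import Regex
from tokenizers.normalizers import Replace, Sequence, Strip

# Collapse runs of whitespace to a single space, then strip trailing whitespace,
# mirroring the two normalizers appended in replace_normalizer.
chain = Sequence([Replace(Regex(r"\s+"), " "), Strip(right=True)])
print(repr(chain.normalize_str("Hello,   World!  ")))  # 'Hello, World!'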

model2vec/tokenizer/tokenizer.py

Lines changed: 6 additions & 3 deletions
@@ -361,10 +361,13 @@ def create_tokenizer(
     token_remove_regex: re.Pattern | None = None,
 ) -> PreTrainedTokenizerFast:
     """
-    Create a tokenizer from a vocabulary.
+    Create a tokenizer by adding tokens to the vocabulary.
 
-    This function creates a tokenizer from a vocabulary and a tokenizer.
-    It also sets the normalizer and pre-tokenizer for the tokenizer.
+    This function turns any tokenizer into a supertoken tokenizer. It does the following:
+    1. Turns the tokenizer model into a unigram model.
+    2. Adds a new pretokenizer, splitting on punctuation.
+    3. Adds all tokens in vocabulary to the model.
+    4. Removes any internal tokens that conform to the regex.
 
     :param tokenizer: The tokenizer to use.
     :param vocabulary: The vocabulary to use.
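
The updated docstring describes the full supertoken conversion. For orientation, a minimal usage sketch matching the call exercised in tests/test_tokenizer.py further down (the removal regex shown is a hypothetical example; the test passes None):

import re

from transformers import AutoTokenizer

from model2vec.tokenizer.tokenizer import create_tokenizer

# Start from any fast tokenizer; the test fixture loads one from tests/data/test_tokenizer.
base_tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer")

# Convert to a unigram "supertoken" tokenizer, add new vocabulary entries,
# and drop internal tokens matching the removal regex.
tokenizer = create_tokenizer(
    tokenizer=base_tokenizer,
    vocabulary=["dog", "catssssss"],
    token_remove_regex=re.compile(r"\[unused\d+\]"),  # hypothetical pattern; pass None to keep all internal tokens
)

# Newly added supertokens encode to a single id.
print(tokenizer.encode("catssssss"))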

tests/conftest.py

Lines changed: 11 additions & 7 deletions
@@ -1,14 +1,14 @@
 from __future__ import annotations
 
-from typing import Any
+from typing import Any, cast
 
 import numpy as np
 import pytest
 import torch
 from tokenizers import Tokenizer
 from tokenizers.models import BPE, Unigram, WordPiece
 from tokenizers.pre_tokenizers import Whitespace
-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast
 
 from model2vec.inference import StaticModelPipeline
 from model2vec.train import StaticModelForClassification
@@ -25,7 +25,9 @@ def mock_tokenizer(request: pytest.FixtureRequest) -> Tokenizer:
     tokenizer_type = request.param
 
     if tokenizer_type == "wordpiece":
-        model = WordPiece(vocab={token: idx for idx, token in enumerate(vocab)}, unk_token=unk_token)
+        model = WordPiece(
+            vocab={token: idx for idx, token in enumerate(vocab)}, unk_token=unk_token, max_input_chars_per_word=100
+        )
     elif tokenizer_type == "bpe":
         model = BPE(
             vocab={token: idx for idx, token in enumerate(vocab)},
@@ -35,17 +37,19 @@ def mock_tokenizer(request: pytest.FixtureRequest) -> Tokenizer:
             ignore_merges=True,
         )
     elif tokenizer_type == "unigram":
-        model = Unigram(vocab=[(token, 0.0) for token in vocab], unk_id=0)
+        model = Unigram(vocab=[(token, 0.0) for token in vocab], unk_id=0, byte_fallback=False)
+    else:
+        raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
     tokenizer = Tokenizer(model)
-    tokenizer.pre_tokenizer = Whitespace()
+    tokenizer.pre_tokenizer = Whitespace()  # type: ignore # Tokenizer issue
 
     return tokenizer
 
 
 @pytest.fixture(scope="function")
-def mock_berttokenizer() -> AutoTokenizer:
+def mock_berttokenizer() -> PreTrainedTokenizerFast:
     """Load the real BertTokenizerFast from the provided tokenizer.json file."""
-    return AutoTokenizer.from_pretrained("tests/data/test_tokenizer")
+    return cast(PreTrainedTokenizerFast, AutoTokenizer.from_pretrained("tests/data/test_tokenizer"))
 
 
 @pytest.fixture

tests/test_tokenizer.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+import json
+
+import pytest
+from transformers import PreTrainedTokenizerFast
+
+from model2vec.tokenizer.model import _calculate_token_weight_for_unigram, _process_unigram, process_tokenizer
+from model2vec.tokenizer.normalizer import replace_normalizer
+from model2vec.tokenizer.pretokenizer import _FORBIDDEN_PRETOKENIZERS, _fix_single_pretokenizer, replace_pretokenizer
+from model2vec.tokenizer.tokenizer import _rename_added_token, create_tokenizer
+
+
+def test_fix_single_pretokenizer() -> None:
+    """Test the _fix_single_pretokenizer function."""
+    result = _fix_single_pretokenizer({"type": "ByteLevel", "add_prefix_space": False, "use_regex": True})
+    assert result == {"type": "ByteLevel", "add_prefix_space": True, "use_regex": False}
+
+    for tokenizer_type in _FORBIDDEN_PRETOKENIZERS:
+        result = _fix_single_pretokenizer({"type": tokenizer_type})
+        assert result is None
+
+    result = _fix_single_pretokenizer(
+        {"type": "Metaspace", "split": True, "prepend_scheme": "never", "replacement": "▁"}
+    )
+    assert result == {"type": "Metaspace", "replacement": "▁", "prepend_scheme": "always", "split": False}
+
+
+def test_replace_pretokenizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None:
+    """Test the replace_pretokenizer function."""
+    tokenizer = replace_pretokenizer(mock_berttokenizer.backend_tokenizer)
+    assert tokenizer.pre_tokenizer is not None
+    assert tokenizer.pre_tokenizer.__class__.__name__ == "Metaspace"
+    assert tokenizer.pre_tokenizer.replacement == "▁"
+    assert tokenizer.pre_tokenizer.prepend_scheme == "always"
+    assert not tokenizer.pre_tokenizer.split
+
+    tokenizer.pre_tokenizer = None  # type: ignore
+    tokenizer = replace_pretokenizer(tokenizer)
+    assert tokenizer.pre_tokenizer is not None
+    assert tokenizer.pre_tokenizer.__class__.__name__ == "Metaspace"
+    assert tokenizer.pre_tokenizer.replacement == "▁"
+    assert tokenizer.pre_tokenizer.prepend_scheme == "always"
+    assert tokenizer.pre_tokenizer.split is False
+
+
+def test_replace_normalizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None:
+    """Test the replace_normalizer function."""
+    tokenizer = replace_normalizer(mock_berttokenizer.backend_tokenizer)
+    assert tokenizer.normalizer is not None
+    assert tokenizer.normalizer.__class__.__name__ == "Sequence"
+
+    assert tokenizer.normalizer.normalize_str("Hello, World!") == "hello , world !"
+
+    tokenizer.normalizer = None  # type: ignore
+    tokenizer = replace_normalizer(tokenizer)
+    assert tokenizer.normalizer.normalize_str("Hello, World!") == "Hello , World !"
+
+
+@pytest.mark.parametrize(
+    "word,weight",
+    [
+        ("dog", 3),
+        ("cat", 3),
+        ("▁longer▁word", 14),
+        ("▁word", 6),
+        ("▁", 2),  # Single underscore
+        ("", 0),  # Empty string
+        ("▁a" * 100, 300),  # Long word with underscores
+    ],
+)
+def test_calculate_token_weight_for_unigram(word: str, weight: int) -> None:
+    """Test the _calculate_token_weight_for_unigram function."""
+    assert _calculate_token_weight_for_unigram(word) == weight
+
+
+def test_process_tokenizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None:
+    """Test the process_tokenizer function."""
+    vocab = ["dog", "cat", "longer_word", "word", "a" * 100, "[UNK]"]
+    tokenizer_json = json.loads(mock_berttokenizer.backend_tokenizer.to_str())
+    tokenizer_json = process_tokenizer(tokenizer_json=tokenizer_json, pre_tokenized_tokens=vocab, unk_token="[UNK]")
+
+    assert tokenizer_json["model"]["type"] == "Unigram"
+    assert tokenizer_json["model"]["unk_id"] == 5  # Index of "[UNK]"
+    assert len(tokenizer_json["model"]["vocab"]) == 6
+    assert all(isinstance(token, tuple) and len(token) == 2 for token in tokenizer_json["model"]["vocab"])
+    for (x, _), y in zip(tokenizer_json["model"]["vocab"], vocab):
+        assert x == y, f"Expected {y}, but got {x}"
+
+
+def test_process_unigram() -> None:
+    """Test the _process_unigram function."""
+    vocab = ["dog", "cat", "longer_word", "word", "a" * 100, "[UNK]"]
+    orig_vocab = [("dog", 0), ("cat", 0)]
+    model = {"model": {"type": "Unigram", "vocab": orig_vocab}}
+    processed_model = _process_unigram(model, vocab, "[UNK]")
+    assert processed_model["model"]["type"] == "Unigram"
+    assert processed_model["model"]["unk_id"] == 5  # Index of "[UNK]"
+    assert len(processed_model["model"]["vocab"]) == 6
+    assert all(isinstance(token, list) and len(token) == 2 for token in processed_model["model"]["vocab"])
+
+    for (x, score), y in zip(processed_model["model"]["vocab"], vocab):
+        assert x == y, f"Expected {y}, but got {x}"
+        if x in orig_vocab:
+            assert score == 0
+
+    assert process_tokenizer(model, vocab, "[UNK]") == processed_model
+
+
+def test_rename_added_token() -> None:
+    """Test the _rename_added_token function."""
+    # Invalid input
+    result = _rename_added_token(None, "a", [{"content": "a", "id": 0}], ["a"])
+    assert result == [{"content": "a", "id": 0}]
+
+    # Rename 'a' to 'c'
+    result = _rename_added_token("a", "c", [{"content": "a"}], ["a"])
+    assert result == [{"content": "c", "id": 0}]
+
+
+def test_create_tokenizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None:
+    """Test the create_tokenizer function."""
+    tokenizer = create_tokenizer(tokenizer=mock_berttokenizer, vocabulary=["dog", "catssssss"], token_remove_regex=None)
+    assert tokenizer.backend_tokenizer.get_vocab_size() == 29525
+    assert tokenizer.encode("catssssss") == [29524]
