import json

import pytest
from transformers import PreTrainedTokenizerFast

from model2vec.tokenizer.model import _calculate_token_weight_for_unigram, _process_unigram, process_tokenizer
from model2vec.tokenizer.normalizer import replace_normalizer
from model2vec.tokenizer.pretokenizer import _FORBIDDEN_PRETOKENIZERS, _fix_single_pretokenizer, replace_pretokenizer
from model2vec.tokenizer.tokenizer import _rename_added_token, create_tokenizer


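# _fix_single_pretokenizer is expected to make ByteLevel pre-tokenizers Metaspace-compatible
# (prefix space on, regex off), drop forbidden pre-tokenizer types entirely, and normalize
# Metaspace settings to prepend_scheme="always" with splitting disabled.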
def test_fix_single_pretokenizer() -> None:
    """Test the _fix_single_pretokenizer function."""
    result = _fix_single_pretokenizer({"type": "ByteLevel", "add_prefix_space": False, "use_regex": True})
    assert result == {"type": "ByteLevel", "add_prefix_space": True, "use_regex": False}

    for tokenizer_type in _FORBIDDEN_PRETOKENIZERS:
        result = _fix_single_pretokenizer({"type": tokenizer_type})
        assert result is None

    result = _fix_single_pretokenizer(
        {"type": "Metaspace", "split": True, "prepend_scheme": "never", "replacement": "▁"}
    )
    assert result == {"type": "Metaspace", "replacement": "▁", "prepend_scheme": "always", "split": False}


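# replace_pretokenizer should install a Metaspace pre-tokenizer (replacement "▁",
# prepend_scheme="always", split disabled) regardless of whether the tokenizer already
# had a pre-tokenizer configured.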
def test_replace_pretokenizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None:
    """Test the replace_pretokenizer function."""
    tokenizer = replace_pretokenizer(mock_berttokenizer.backend_tokenizer)
    assert tokenizer.pre_tokenizer is not None
    assert tokenizer.pre_tokenizer.__class__.__name__ == "Metaspace"
    assert tokenizer.pre_tokenizer.replacement == "▁"
    assert tokenizer.pre_tokenizer.prepend_scheme == "always"
    assert not tokenizer.pre_tokenizer.split

    tokenizer.pre_tokenizer = None  # type: ignore
    tokenizer = replace_pretokenizer(tokenizer)
    assert tokenizer.pre_tokenizer is not None
    assert tokenizer.pre_tokenizer.__class__.__name__ == "Metaspace"
    assert tokenizer.pre_tokenizer.replacement == "▁"
    assert tokenizer.pre_tokenizer.prepend_scheme == "always"
    assert tokenizer.pre_tokenizer.split is False


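# replace_normalizer is expected to produce a Sequence normalizer that separates punctuation
# with spaces; lowercasing only applies when the original tokenizer already had a normalizer,
# hence the different casing in the two assertions below.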
def test_replace_normalizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None:
    """Test the replace_normalizer function."""
    tokenizer = replace_normalizer(mock_berttokenizer.backend_tokenizer)
    assert tokenizer.normalizer is not None
    assert tokenizer.normalizer.__class__.__name__ == "Sequence"

    assert tokenizer.normalizer.normalize_str("Hello, World!") == "hello , world !"

    tokenizer.normalizer = None  # type: ignore
    tokenizer = replace_normalizer(tokenizer)
    assert tokenizer.normalizer.normalize_str("Hello, World!") == "Hello , World !"


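# The expected weights below are consistent with len(word) + word.count("▁"): every character
# counts once and the Metaspace marker "▁" counts twice. This formula is inferred from the
# test cases, not taken from the implementation.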
@pytest.mark.parametrize(
    "word,weight",
    [
        ("dog", 3),
        ("cat", 3),
        ("▁longer▁word", 14),
        ("▁word", 6),
        ("▁", 2),  # Single Metaspace marker
        ("", 0),  # Empty string
        ("▁a" * 100, 300),  # Long word interleaved with Metaspace markers
    ],
)
def test_calculate_token_weight_for_unigram(word: str, weight: int) -> None:
    """Test the _calculate_token_weight_for_unigram function."""
    assert _calculate_token_weight_for_unigram(word) == weight


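# process_tokenizer should convert the mock BERT tokenizer's JSON into a Unigram model whose
# vocabulary is exactly the pre-tokenized token list, with unk_id pointing at "[UNK]".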
def test_process_tokenizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None:
    """Test the process_tokenizer function."""
    vocab = ["dog", "cat", "longer_word", "word", "a" * 100, "[UNK]"]
    tokenizer_json = json.loads(mock_berttokenizer.backend_tokenizer.to_str())
    tokenizer_json = process_tokenizer(tokenizer_json=tokenizer_json, pre_tokenized_tokens=vocab, unk_token="[UNK]")

    assert tokenizer_json["model"]["type"] == "Unigram"
    assert tokenizer_json["model"]["unk_id"] == 5  # Index of "[UNK]"
    assert len(tokenizer_json["model"]["vocab"]) == 6
    assert all(isinstance(token, tuple) and len(token) == 2 for token in tokenizer_json["model"]["vocab"])
    for (x, _), y in zip(tokenizer_json["model"]["vocab"], vocab):
        assert x == y, f"Expected {y}, but got {x}"


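# _process_unigram is the Unigram-specific path that process_tokenizer delegates to (as the
# final assertion checks): the output vocabulary follows the order of the passed token list,
# and tokens already present in the original vocab are expected to keep their score (0 here).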
def test_process_unigram() -> None:
    """Test the _process_unigram function."""
    vocab = ["dog", "cat", "longer_word", "word", "a" * 100, "[UNK]"]
    orig_vocab = [("dog", 0), ("cat", 0)]
    orig_tokens = {token for token, _ in orig_vocab}
    model = {"model": {"type": "Unigram", "vocab": orig_vocab}}
    processed_model = _process_unigram(model, vocab, "[UNK]")
    assert processed_model["model"]["type"] == "Unigram"
    assert processed_model["model"]["unk_id"] == 5  # Index of "[UNK]"
    assert len(processed_model["model"]["vocab"]) == 6
    assert all(isinstance(token, list) and len(token) == 2 for token in processed_model["model"]["vocab"])

    for (x, score), y in zip(processed_model["model"]["vocab"], vocab):
        assert x == y, f"Expected {y}, but got {x}"
        # orig_vocab holds (token, score) tuples, so check membership against the token
        # strings; comparing x against the tuples themselves would never match.
        if x in orig_tokens:
            assert score == 0

    assert process_tokenizer(model, vocab, "[UNK]") == processed_model


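# _rename_added_token appears to return the added-tokens list with the renamed entry assigned
# the id of its position in the vocabulary (0 here); passing None as the token to rename
# leaves the list unchanged.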
def test_rename_added_token() -> None:
    """Test the _rename_added_token function."""
    # Invalid input
    result = _rename_added_token(None, "a", [{"content": "a", "id": 0}], ["a"])
    assert result == [{"content": "a", "id": 0}]

    # Rename 'a' to 'c'
    result = _rename_added_token("a", "c", [{"content": "a"}], ["a"])
    assert result == [{"content": "c", "id": 0}]


def test_create_tokenizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None:
    """Test the create_tokenizer function."""
    tokenizer = create_tokenizer(tokenizer=mock_berttokenizer, vocabulary=["dog", "catssssss"], token_remove_regex=None)
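    # The hard-coded ids below are specific to the mock_berttokenizer fixture: the processed
    # vocabulary plus the added "catssssss" token appears to total 29525 entries, so the new
    # token receives the last id, 29524.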
    assert tokenizer.backend_tokenizer.get_vocab_size() == 29525
    assert tokenizer.encode("catssssss") == [29524]