Commit 3f39da4

fix: new version of skeletoken
1 parent: 829fad6 · commit: 3f39da4

2 files changed: +10 -5 lines changed

model2vec/distill/distillation.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,7 @@
 
 import numpy as np
 from huggingface_hub import model_info
+from skeletoken import TokenizerModel
 from transformers import AutoModel, AutoTokenizer
 from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
@@ -16,6 +17,7 @@
 from model2vec.model import StaticModel
 from model2vec.quantization import DType, quantize_embeddings
 from model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
+from model2vec.tokenizer.tokenizer import _patch_tokenizer
 
 logger = logging.getLogger(__name__)
 
@@ -86,7 +88,10 @@ def distill_from_model(
 
     logger.info(f"Creating embeddings for {len(tokens)} tokens")
     # Convert tokens to IDs
-    token_ids = turn_tokens_into_ids(tokens, tokenizer.backend_tokenizer)
+    m = _patch_tokenizer(tokenizer=tokenizer, lower_case=False)
+    bb = m.to_tokenizer()
+
+    token_ids = turn_tokens_into_ids(tokens, bb)
 
     # Create the embeddings
     pad_token = cast(str | None, tokenizer.special_tokens_map.get("pad_token", None))
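
For orientation, the call-site change amounts to the flow sketched below. This is inferred from the diff only: the `tokens_to_ids` helper name is made up here, `_patch_tokenizer` is assumed to return a skeletoken `TokenizerModel`, and `to_tokenizer()` is assumed to yield a `tokenizers.Tokenizer`, matching what `tokenizer.backend_tokenizer` used to provide.

from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

from model2vec.tokenizer import turn_tokens_into_ids
from model2vec.tokenizer.tokenizer import _patch_tokenizer


def tokens_to_ids(tokenizer: PreTrainedTokenizerFast, tokens: list):
    """Convert vocabulary tokens to IDs via a skeletoken-patched tokenizer."""
    # Build a skeletoken TokenizerModel from the HF tokenizer without lowercasing,
    # instead of reading tokenizer.backend_tokenizer directly as before.
    tokenizer_model = _patch_tokenizer(tokenizer=tokenizer, lower_case=False)
    # Convert the skeletoken model back into a backend tokenizer for the ID lookup.
    backend_tokenizer: Tokenizer = tokenizer_model.to_tokenizer()
    return turn_tokens_into_ids(tokens, backend_tokenizer)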

model2vec/tokenizer/tokenizer.py

Lines changed: 4 additions & 4 deletions
@@ -5,7 +5,7 @@
 from typing import cast
 
 from skeletoken import TokenizerModel
-from skeletoken.addedtoken import AddedToken
+from skeletoken.addedtoken import AddedToken, AddedTokens
 from skeletoken.models import WordPiece
 from skeletoken.pretokenizers import ByteLevelPreTokenizer, PreTokenizerSequence
 from tokenizers import Tokenizer
@@ -50,7 +50,7 @@ def replace_vocabulary(tokenizer: Tokenizer, new_vocabulary: list[Token]) -> Tok
     tokenizer_model.model.vocab.replace_vocabulary(tokens)
 
     new_added_tokens = []
-    for added_token in tokenizer_model.added_tokens:
+    for added_token in tokenizer_model.added_tokens.root:
         if added_token.content not in {tokenizer_model.unk_token, tokenizer_model.pad_token}:
             continue
         new_added_tokens.append(added_token)
@@ -70,7 +70,7 @@ def replace_vocabulary(tokenizer: Tokenizer, new_vocabulary: list[Token]) -> Tok
     )
 
     pre_tokenized_tokens = [x.normalized_form for x in new_vocabulary]
-    tokenizer_model.added_tokens = _remap_added_tokens(new_added_tokens, pre_tokenized_tokens)
+    tokenizer_model.added_tokens = AddedTokens(_remap_added_tokens(new_added_tokens, pre_tokenized_tokens))
     # Set post processor to None because we don't care about it
     tokenizer_model.post_processor = None
     # We need to re-set the pad and unk tokens to put the correct indices.
@@ -166,7 +166,7 @@ def _process_internal_tokens(
     added_tokens_to_keep: set[str] = {
         x for x in (tokenizer_model.pad_token, tokenizer_model.unk_token) if x is not None
     }
-    added_tokens_to_remove = {x.content for x in tokenizer_model.added_tokens} - added_tokens_to_keep
+    added_tokens_to_remove = {x.content for x in tokenizer_model.added_tokens.root} - added_tokens_to_keep
     cleaned_internal_tokens: list[Token] = []
 
     for token in internal_tokens:
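
All four tokenizer.py changes track the same skeletoken API shift: `TokenizerModel.added_tokens` is now an `AddedTokens` container rather than a plain list, so iteration goes through its `.root` attribute and assignments wrap the list back up. A minimal sketch of the pattern, inferred from this diff rather than from skeletoken's documentation (the helper name is made up):

from skeletoken import TokenizerModel
from skeletoken.addedtoken import AddedToken, AddedTokens


def keep_only_special_added_tokens(tokenizer_model: TokenizerModel) -> None:
    """Keep only the pad/unk added tokens, mirroring replace_vocabulary above."""
    keep = {t for t in (tokenizer_model.pad_token, tokenizer_model.unk_token) if t is not None}
    kept: list[AddedToken] = [
        token
        for token in tokenizer_model.added_tokens.root  # .root exposes the wrapped list
        if token.content in keep
    ]
    # Assign back through the AddedTokens wrapper rather than as a bare list.
    tokenizer_model.added_tokens = AddedTokens(kept)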
