Commit fa45e93

fix: add skeletoken
1 parent 4b81fad commit fa45e93

File tree

5 files changed: 45 additions & 31 deletions

model2vec/tokenizer/tokenizer.py

Lines changed: 9 additions & 20 deletions
@@ -56,18 +56,7 @@ def replace_vocabulary(tokenizer: Tokenizer, new_vocabulary: list[Token]) -> Tok
             new_added_tokens.append(added_token)
     for token in new_vocabulary:
         if token.is_multiword and token.form not in {tokenizer_model.unk_token, tokenizer_model.pad_token}:
-            token_id = tokenizer_model.model.vocab[token.form]
-            new_added_tokens.append(
-                AddedToken(
-                    content=token.form,
-                    single_word=False,
-                    lstrip=True,
-                    rstrip=True,
-                    normalized=True,
-                    special=False,
-                    id=token_id,
-                )
-            )
+            tokenizer_model.add_addedtoken(token.form, normalized=True, single_word=False)
 
     pre_tokenized_tokens = [x.normalized_form for x in new_vocabulary]
     tokenizer_model.added_tokens = AddedTokens(_remap_added_tokens(new_added_tokens, pre_tokenized_tokens))
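The hunk above replaces the hand-built AddedToken (vocab lookup plus explicit lstrip/rstrip/special flags) with a single skeletoken helper call that resolves those details internally. A minimal sketch of the new call path, assuming TokenizerModel is importable from the skeletoken package; the import path and example token are assumptions, not taken from this diff:

from tokenizers import Tokenizer
from skeletoken import TokenizerModel  # import path assumed

# Wrap a concrete tokenizer, as find_eos_bos below does.
tok = Tokenizer.from_pretrained("bert-base-uncased")
model = TokenizerModel.from_tokenizer(tok)

# One call replaces the twelve removed lines of manual AddedToken
# construction; id resolution and token flags are handled internally.
model.add_addedtoken("new york", normalized=True, single_word=False)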
@@ -234,32 +223,32 @@ def turn_tokens_into_ids(tokens: list[Token], tokenizer: Tokenizer) -> list[list
     """
     prefix, suffix = find_eos_bos(tokenizer)
 
-    prefix_id, suffix_id = None, None
+    prefix_ids, suffix_ids = None, None
     vocab = tokenizer.get_vocab()
     if prefix is not None:
-        prefix_id = vocab[prefix]
+        prefix_ids = [vocab[token] for token in prefix]
     if suffix is not None:
-        suffix_id = vocab[suffix]
+        suffix_ids = [vocab[token] for token in suffix]
 
     token_ids: list[list[int]] = []
     for token in tokens:
         token_sequence = []
-        if prefix_id is not None:
-            token_sequence.append(prefix_id)
+        if prefix_ids is not None:
+            token_sequence.extend(prefix_ids)
         if token.is_internal:
             token_id = vocab[token.form]
             token_sequence.append(token_id)
         else:
             token_sequence.extend(tokenizer.encode(token.form).ids)
-        if suffix_id is not None:
-            token_sequence.append(suffix_id)
+        if suffix_ids is not None:
+            token_sequence.extend(suffix_ids)
 
         token_ids.append(token_sequence)
 
     return token_ids
 
 
-def find_eos_bos(tokenizer: Tokenizer) -> tuple[str | None, str | None]:
+def find_eos_bos(tokenizer: Tokenizer) -> tuple[list[str] | None, list[str] | None]:
     """Finds the eos and bos tokens for a tokenizer."""
     model = TokenizerModel.from_tokenizer(tokenizer)
     return model.bos, model.eos
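The second hunk generalizes bos/eos handling from a single token id to a list of ids, so tokenizers whose template prepends or appends more than one special token still round-trip. A toy illustration of the new assembly logic, with made-up ids that do not come from any real vocabulary:

# Hypothetical ids: a two-token prefix and a one-token suffix, as
# find_eos_bos can now return lists of special tokens.
prefix_ids: list[int] | None = [0, 5]
suffix_ids: list[int] | None = [2]

def wrap(inner_ids: list[int]) -> list[int]:
    # Mirrors the loop body above: extend with the whole prefix/suffix
    # sequence instead of appending a single id.
    token_sequence: list[int] = []
    if prefix_ids is not None:
        token_sequence.extend(prefix_ids)
    token_sequence.extend(inner_ids)
    if suffix_ids is not None:
        token_sequence.extend(suffix_ids)
    return token_sequence

print(wrap([731]))  # [0, 5, 731, 2]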

model2vec/utils.py

Lines changed: 3 additions & 4 deletions
@@ -1,18 +1,14 @@
 # -*- coding: utf-8 -*-
 from __future__ import annotations
 
-import json
 import logging
 import re
 from importlib import import_module
 from importlib.metadata import metadata
-from pathlib import Path
 from typing import Any, Iterator, Protocol, cast
 
 import numpy as np
-import safetensors
 from joblib import Parallel
-from tokenizers import Tokenizer
 from tqdm import tqdm
 
 logger = logging.getLogger(__name__)
@@ -78,12 +74,15 @@ def get_package_extras(package: str, extra: str) -> Iterator[str]:
         found_extra = rest[0].split("==")[-1].strip(" \"'")
         if found_extra == extra:
             prefix, *_ = _DIVIDERS.split(name)
+            prefix = prefix.split("@")[0].strip()
             yield prefix.strip()
 
 
 def importable(module: str, extra: str) -> None:
     """Check if a module is importable."""
     module = dict(_MODULE_MAP).get(module, module)
+    # Allows this to work with git installed modules.
+    module = module.split("@")[0].strip()
     try:
         import_module(module)
     except ImportError:
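Both utils.py hunks strip a PEP 508 direct reference of the form "name @ git+url" down to the bare distribution name, so extras parsing and the import check keep working for git-installed dependencies. A quick illustration of the string handling, using the requirement removed from pyproject.toml below:

requirement = "skeletoken @ git+https://github.com/stephantul/skeletoken.git"

# The same split now applied in get_package_extras and importable:
name = requirement.split("@")[0].strip()
print(name)  # skeletoken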

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ dev = [
     "ruff",
 ]
 
-distill = ["torch", "transformers", "scikit-learn", "skeletoken @ git+https://github.com/stephantul/skeletoken.git"]
+distill = ["torch", "transformers", "scikit-learn", "skeletoken"]
 onnx = ["onnx", "torch"]
 # train also installs inference
 train = ["torch", "lightning", "scikit-learn", "skops"]
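This swaps the git direct reference for a plain requirement, presumably because skeletoken is now published on PyPI, so pip install "model2vec[distill]" resolves it like any other extra.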

tests/test_utils.py

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ def test_importable() -> None:
 def test_get_package_extras() -> None:
     """Test package extras."""
     extras = set(get_package_extras("model2vec", "distill"))
-    assert extras == {"torch", "transformers", "scikit-learn"}
+    assert extras == {"torch", "transformers", "scikit-learn", "skeletoken"}
 
 
 def test_get_package_extras_empty() -> None:

uv.lock

Lines changed: 31 additions & 5 deletions
Some generated files are not rendered by default.
