Commit 3f62e19

typing fixes

Author: moo
Parent: 9e82f35

2 files changed, +25 -7 lines changed

rigging/chat.py

Lines changed: 2 additions & 2 deletions

@@ -422,7 +422,7 @@ def to_openai(self) -> list[dict[str, t.Any]]:
     async def to_tokens(
         self,
         tokenizer: str,
-        tokenizer_kwargs: dict[str, t.Any] | None = None,
+        tokenizer_kwargs: dict[str, t.Any] = {},
         *,
         apply_chat_template_kwargs: dict[str, t.Any] | None = None,
         encode_kwargs: dict[str, t.Any] | None = None,
@@ -538,7 +538,7 @@ def to_openai(self) -> list[list[dict[str, t.Any]]]:
     async def to_tokens(
         self,
         tokenizer: str,
-        tokenizer_kwargs: dict[str, t.Any] | None = None,
+        tokenizer_kwargs: dict[str, t.Any] = {},
         *,
         apply_chat_template_kwargs: dict[str, t.Any] | None = None,
         encode_kwargs: dict[str, t.Any] | None = None,
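Note: both `to_tokens` signatures now default `tokenizer_kwargs` to a literal `{}` rather than `None`. A minimal standalone sketch of the call-site effect (plain Python, not rigging code; the function names here are hypothetical):

import typing as t

def forward_old(tokenizer_kwargs: dict[str, t.Any] | None = None) -> dict[str, t.Any]:
    # Old signature: every consumer must normalize None before forwarding.
    return dict(tokenizer_kwargs or {})

def forward_new(tokenizer_kwargs: dict[str, t.Any] = {}) -> dict[str, t.Any]:  # noqa: B006
    # New signature: the mapping can be forwarded directly, but the shared
    # mutable default must be treated as read-only across calls.
    return dict(tokenizer_kwargs)

assert forward_old() == forward_new() == {}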

rigging/tokenize/tokenizer.py

Lines changed: 23 additions & 5 deletions

@@ -7,26 +7,44 @@
 
 from transformers import AutoTokenizer
 
+from rigging.logging import logger
 from rigging.tokenize.base import Decoder
 
 
 def get_tokenizer(
-    model: str | t.Any,
+    tokenizer_id: str | AutoTokenizer | None,
     **tokenizer_kwargs: t.Any,
 ) -> AutoTokenizer:
     """
     Get the tokenizer from transformers model identifier, or from an already loaded tokenizer.
 
     Args:
-        model: The model identifier (string) or an already loaded tokenizer.
+        tokenizer_id: The model identifier (string) or an already loaded tokenizer.
         tokenizer_kwargs: Additional keyword arguments for the tokenizer initialization.
 
     Returns:
         An instance of `AutoTokenizer`.
     """
-    if isinstance(model, str):
-        return AutoTokenizer.from_pretrained(model, **tokenizer_kwargs)
-    return model
+    if isinstance(tokenizer_id, str):
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_id,
+                **tokenizer_kwargs,
+            )
+            logger.success(f"Loaded tokenizer for model '{tokenizer_id}'")
+
+        except Exception as e:  # noqa: BLE001
+            # Catch all exceptions to handle any issues with loading the tokenizer
+            logger.error(f"Failed to load tokenizer for model '{tokenizer_id}': {e}")
+
+        elif isinstance(tokenizer_id, AutoTokenizer):
+            return tokenizer
+
+    else:
+        tokenizer = None
+        logger.error("tokenizer_id must be a string or an instance of AutoTokenizer.")
+
+    return tokenizer
 
 
 def find_in_tokens(
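Note: a hedged call-site sketch of the new behavior ("gpt2" is only an example model id; `use_fast` is a standard `AutoTokenizer.from_pretrained` keyword). Because the new `else` branch can return `None` despite the `-> AutoTokenizer` annotation, defensive callers may want to check the result:

from rigging.tokenize.tokenizer import get_tokenizer

tokenizer = get_tokenizer("gpt2", use_fast=True)  # example id, assumed available
if tokenizer is None:
    raise RuntimeError("tokenizer failed to load; see the logged error")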
