
Commit d1b726a

Author: moo
Commit message: docs
1 parent 1f79691

File tree

4 files changed (+47, -36 lines)


docs/api/data.mdx

Lines changed: 15 additions & 7 deletions
@@ -258,7 +258,7 @@ chats\_to\_tokens
 ```python
 chats_to_tokens(
     chat: Chat,
-    tokenizer: AutoTokenizer,
+    tokenizer: PreTrainedTokenizerBase,
     *,
     apply_chat_template_kwargs: dict[str, Any]
     | None = None,
@@ -275,7 +275,7 @@ Transform a chat into a tokenized format with structured slices.
 (`Chat`)
 –The chat object to tokenize.
 * **`tokenizer`**
-(`AutoTokenizer`)
+(`PreTrainedTokenizerBase`)
 –The tokenizer to use for encoding and decoding.

 **Returns:**
@@ -287,7 +287,7 @@ Transform a chat into a tokenized format with structured slices.
 ```python
 async def chats_to_tokens(
     chat: Chat,
-    tokenizer: AutoTokenizer,
+    tokenizer: "PreTrainedTokenizerBase",
     *,
     apply_chat_template_kwargs: dict[str, t.Any] | None = None,
     encode_kwargs: dict[str, t.Any] | None = None,
@@ -323,8 +323,9 @@ async def chats_to_tokens(
         if chat.params and chat.params.tools
         else None
     )
+    # the tools above return dict[str, Any], but Transformers expects list[dict[Any, Any]]

-    chat_text = tokenizer.apply_chat_template(messages, tools=tools, **apply_chat_template_kwargs)
+    chat_text = tokenizer.apply_chat_template(messages, tools=tools, **apply_chat_template_kwargs)  # type: ignore[arg-type]
     chat_tokens = tokenizer.encode(chat_text, **encode_kwargs)

     slices: list[TokenSlice] = []
@@ -334,7 +335,13 @@ async def chats_to_tokens(
     for message in chat.all:
         # Find this message
         if not (
-            match := find_in_tokens(message.content, chat_tokens, tokenizer.decode, 0, search_start)
+            match := find_in_tokens(
+                message.content,
+                chat_tokens,
+                lambda tokens: tokenizer.decode(tokens),
+                0,
+                search_start,
+            )
         ):
             warnings.warn(
                 f"Warning: Could not find message '{message.content[:50]}...' in chat tokens",
@@ -370,7 +377,7 @@ async def chats_to_tokens(
             part_match = find_in_tokens(
                 part_text,
                 message_tokens,
-                tokenizer.decode,
+                lambda tokens: tokenizer.decode(tokens),
                 msg_start,
                 part_search_start,
             )
@@ -399,8 +406,9 @@ async def chats_to_tokens(
         # Continue searching after this message
         search_start = msg_end

+    # we ask for a string by default in apply_chat_template_kwargs with the tokenize=False
     return TokenizedChat(
-        text=chat_text,
+        text=chat_text,  # type: ignore[arg-type]
         tokens=chat_tokens,
         slices=slices,
         obj=chat,
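For readers of the updated docs page: `chats_to_tokens` now expects a concrete tokenizer deriving from `PreTrainedTokenizerBase`, which is what `AutoTokenizer.from_pretrained` returns, rather than the `AutoTokenizer` factory class itself. A minimal usage sketch under that assumption; the model id is illustrative and not taken from this commit:

```python
from transformers import AutoTokenizer

from rigging.chat import Chat
from rigging.data import chats_to_tokens


async def tokenize_chat(chat: Chat) -> None:
    # AutoTokenizer.from_pretrained returns a PreTrainedTokenizerBase subclass,
    # which matches the updated parameter annotation.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example model id

    tokenized = await chats_to_tokens(chat, tokenizer)

    print(tokenized.text[:200])   # templated chat text
    print(len(tokenized.tokens))  # encoded token ids
    print(tokenized.slices)       # TokenSlice entries mapping messages/parts to token ranges
```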

docs/api/tokenize.mdx

Lines changed: 9 additions & 7 deletions
@@ -12,7 +12,7 @@ get\_tokenizer
 ```python
 get_tokenizer(
     tokenizer_id: str, **tokenizer_kwargs: Any
-) -> AutoTokenizer | None
+) -> t.Any
 ```

 Get the tokenizer from transformers model identifier, or from an already loaded tokenizer.
@@ -30,15 +30,15 @@ Get the tokenizer from transformers model identifier, or from an already loaded

 **Returns:**

-* `AutoTokenizer | None`
+* `Any`
 –An instance of `AutoTokenizer`.

 <Accordion title="Source code in rigging/tokenize/tokenizer.py" icon="code">
 ```python
 def get_tokenizer(
     tokenizer_id: str,
     **tokenizer_kwargs: t.Any,
-) -> AutoTokenizer | None:
+) -> t.Any:
     """
     Get the tokenizer from transformers model identifier, or from an already loaded tokenizer.

@@ -49,18 +49,20 @@ def get_tokenizer(
     Returns:
         An instance of `AutoTokenizer`.
     """
-    tokenizer: AutoTokenizer | None = None
-
     try:
+        from transformers import AutoTokenizer
+
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_id,
             **tokenizer_kwargs,
         )
         logger.success(f"Loaded tokenizer for model '{tokenizer_id}'")

-    except Exception as e:  # noqa: BLE001
+    except Exception as e:
         # Catch all exceptions to handle any issues with loading the tokenizer
-        logger.error(f"Failed to load tokenizer for model '{tokenizer_id}': {e}")
+        raise RuntimeError(
+            f"Failed to load tokenizer for model '{tokenizer_id}': {e}",
+        ) from e

     return tokenizer
 ```
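Behaviorally, `get_tokenizer` now raises on failure instead of logging an error and returning `None`, so callers that previously checked for a `None` result would catch the exception instead. A small sketch of the new calling pattern; the model id is illustrative:

```python
from rigging.tokenize.tokenizer import get_tokenizer

try:
    # Any failure (unknown id, missing files, transformers not installed) now
    # surfaces as a RuntimeError chained from the original exception.
    tokenizer = get_tokenizer("gpt2")
except RuntimeError as error:
    print(f"Could not load tokenizer: {error}")
else:
    print(type(tokenizer).__name__)  # e.g. GPT2TokenizerFast
```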

rigging/data.py

Lines changed: 16 additions & 6 deletions
@@ -12,14 +12,16 @@
 import pandas as pd
 from elastic_transport import ObjectApiResponse
 from mypy_boto3_s3 import S3Client
-from transformers import AutoTokenizer

 from rigging.chat import Chat
 from rigging.error import TokenizeWarning
 from rigging.message import Message
 from rigging.tokenize import find_in_tokens
 from rigging.tokenize.base import TokenizedChat, TokenSlice

+if t.TYPE_CHECKING:
+    from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+

 def flatten_chats(chats: Chat | t.Sequence[Chat]) -> list[dict[t.Any, t.Any]]:
     """
@@ -295,7 +297,7 @@ async def chats_to_elastic(

 async def chats_to_tokens(
     chat: Chat,
-    tokenizer: AutoTokenizer,
+    tokenizer: "PreTrainedTokenizerBase",
     *,
     apply_chat_template_kwargs: dict[str, t.Any] | None = None,
     encode_kwargs: dict[str, t.Any] | None = None,
@@ -331,8 +333,9 @@ async def chats_to_tokens(
         if chat.params and chat.params.tools
         else None
     )
+    # the tools above return dict[str, Any], but Transformers expects list[dict[Any, Any]]

-    chat_text = tokenizer.apply_chat_template(messages, tools=tools, **apply_chat_template_kwargs)
+    chat_text = tokenizer.apply_chat_template(messages, tools=tools, **apply_chat_template_kwargs)  # type: ignore[arg-type]
     chat_tokens = tokenizer.encode(chat_text, **encode_kwargs)

     slices: list[TokenSlice] = []
@@ -342,7 +345,13 @@ async def chats_to_tokens(
     for message in chat.all:
         # Find this message
         if not (
-            match := find_in_tokens(message.content, chat_tokens, tokenizer.decode, 0, search_start)
+            match := find_in_tokens(
+                message.content,
+                chat_tokens,
+                lambda tokens: tokenizer.decode(tokens),
+                0,
+                search_start,
+            )
         ):
             warnings.warn(
                 f"Warning: Could not find message '{message.content[:50]}...' in chat tokens",
@@ -378,7 +387,7 @@ async def chats_to_tokens(
             part_match = find_in_tokens(
                 part_text,
                 message_tokens,
-                tokenizer.decode,
+                lambda tokens: tokenizer.decode(tokens),
                 msg_start,
                 part_search_start,
             )
@@ -407,8 +416,9 @@ async def chats_to_tokens(
         # Continue searching after this message
         search_start = msg_end

+    # we ask for a string by default in apply_chat_template_kwargs with the tokenize=False
     return TokenizedChat(
-        text=chat_text,
+        text=chat_text,  # type: ignore[arg-type]
         tokens=chat_tokens,
         slices=slices,
         obj=chat,
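Two patterns in this file are worth noting. The `if t.TYPE_CHECKING:` guard keeps `transformers` out of the module's runtime imports while still allowing the `"PreTrainedTokenizerBase"` string annotation, and the `lambda tokens: tokenizer.decode(tokens)` wrapper narrows the tokenizer's flexible `decode` signature to the single-argument decoder callable that `find_in_tokens` takes. A rough sketch of both ideas, assuming the decoder is approximately a `Callable[[list[int]], str]`; the exact `Decoder` definition is not shown in this commit:

```python
import typing as t

if t.TYPE_CHECKING:
    # Only imported for type checking, mirroring rigging/data.py; the annotation
    # below stays a string so transformers is not required at runtime.
    from transformers.tokenization_utils_base import PreTrainedTokenizerBase

# Assumed shape of the decoder callable expected by find_in_tokens.
DecoderLike = t.Callable[[list[int]], str]


def make_decoder(tokenizer: "PreTrainedTokenizerBase") -> DecoderLike:
    # tokenizer.decode accepts several input types plus keyword arguments, so a
    # single-argument lambda keeps the call site within the narrower decoder shape.
    return lambda tokens: tokenizer.decode(tokens)
```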

rigging/tokenize/tokenizer.py

Lines changed: 7 additions & 16 deletions
@@ -1,25 +1,14 @@
 import typing as t
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from transformers import AutoTokenizer

 from loguru import logger

-try:
-    from transformers import AutoTokenizer
-except ImportError:
-    raise ModuleNotFoundError(
-        "Please install the `transformers` package to use this module.",
-    ) from None
-
 from rigging.tokenize.base import Decoder


 def get_tokenizer(
     tokenizer_id: str,
     **tokenizer_kwargs: t.Any,
-) -> AutoTokenizer | None:
+) -> t.Any:
     """
     Get the tokenizer from transformers model identifier, or from an already loaded tokenizer.

@@ -30,18 +19,20 @@ def get_tokenizer(
     Returns:
         An instance of `AutoTokenizer`.
     """
-    tokenizer: AutoTokenizer | None = None
-
     try:
+        from transformers import AutoTokenizer
+
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_id,
             **tokenizer_kwargs,
         )
         logger.success(f"Loaded tokenizer for model '{tokenizer_id}'")

-    except Exception as e:  # noqa: BLE001
+    except Exception as e:
         # Catch all exceptions to handle any issues with loading the tokenizer
-        logger.error(f"Failed to load tokenizer for model '{tokenizer_id}': {e}")
+        raise RuntimeError(
+            f"Failed to load tokenizer for model '{tokenizer_id}': {e}",
+        ) from e

     return tokenizer
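The import of `AutoTokenizer` moves inside the `try` block, so `transformers` becomes a call-time dependency: importing `rigging.tokenize.tokenizer` no longer requires the package, and a missing install is reported through the same `RuntimeError` path as any other loading failure. A compressed sketch of the pattern, using a hypothetical stand-in rather than the exact module source:

```python
import typing as t


def load_tokenizer(tokenizer_id: str, **tokenizer_kwargs: t.Any) -> t.Any:
    # Hypothetical helper illustrating the lazy-import + re-raise pattern.
    try:
        # Deferred import: if transformers is not installed, the ImportError is
        # raised here at call time and converted to RuntimeError below.
        from transformers import AutoTokenizer

        return AutoTokenizer.from_pretrained(tokenizer_id, **tokenizer_kwargs)
    except Exception as e:
        raise RuntimeError(
            f"Failed to load tokenizer for model '{tokenizer_id}': {e}",
        ) from e
```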
