Commit 7948cb9

Merge pull request #159 from dreadnode/fix/transformers_type_check
fix: typing check for transformers
2 parents 592243e + d1b726a commit 7948cb9

File tree

4 files changed: 51 additions & 35 deletions

* docs/api/data.mdx
* docs/api/tokenize.mdx
* rigging/data.py
* rigging/tokenize/tokenizer.py
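
This commit swaps `AutoTokenizer` annotations for `PreTrainedTokenizerBase` (or `t.Any`), makes the `transformers` import lazy, and tightens error handling. The underlying issue: `AutoTokenizer` is a factory class whose `from_pretrained` returns instances of other tokenizer classes, so annotating a parameter as `AutoTokenizer` fails type checking; `PreTrainedTokenizerBase` is the common base class of what actually gets returned. A minimal sketch of the distinction, assuming `transformers` is installed (`token_count` is a hypothetical helper, not part of this commit):

```python
import typing as t

if t.TYPE_CHECKING:
    from transformers.tokenization_utils_base import PreTrainedTokenizerBase


def token_count(tokenizer: "PreTrainedTokenizerBase", text: str) -> int:
    # PreTrainedTokenizerBase is the base class of the objects that
    # AutoTokenizer.from_pretrained() actually returns, so this annotation
    # type-checks where `tokenizer: AutoTokenizer` (the factory class) fails.
    return len(tokenizer.encode(text))


if __name__ == "__main__":
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # any model id works here
    print(token_count(tok, "hello world"))
```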

docs/api/data.mdx

Lines changed: 18 additions & 10 deletions
````diff
@@ -257,8 +257,8 @@ chats_to_tokens
 
 ```python
 chats_to_tokens(
-    chat: Chat | None,
-    tokenizer: AutoTokenizer,
+    chat: Chat,
+    tokenizer: PreTrainedTokenizerBase,
     *,
     apply_chat_template_kwargs: dict[str, Any]
     | None = None,
@@ -272,10 +272,10 @@ Transform a chat into a tokenized format with structured slices.
 **Parameters:**
 
 * **`chat`**
-  (`Chat | None`)
+  (`Chat`)
   –The chat object to tokenize.
 * **`tokenizer`**
-  (`AutoTokenizer`)
+  (`PreTrainedTokenizerBase`)
   –The tokenizer to use for encoding and decoding.
 
 **Returns:**
@@ -286,8 +286,8 @@ Transform a chat into a tokenized format with structured slices.
 <Accordion title="Source code in rigging/data.py" icon="code">
 ```python
 async def chats_to_tokens(
-    chat: Chat | None,
-    tokenizer: AutoTokenizer,
+    chat: Chat,
+    tokenizer: "PreTrainedTokenizerBase",
     *,
     apply_chat_template_kwargs: dict[str, t.Any] | None = None,
     encode_kwargs: dict[str, t.Any] | None = None,
@@ -323,8 +323,9 @@ async def chats_to_tokens(
         if chat.params and chat.params.tools
         else None
     )
+    # the tools above return dict[str, Any], but Transformers expects list[dict[Any, Any]]
 
-    chat_text = tokenizer.apply_chat_template(messages, tools=tools, **apply_chat_template_kwargs)
+    chat_text = tokenizer.apply_chat_template(messages, tools=tools, **apply_chat_template_kwargs)  # type: ignore[arg-type]
     chat_tokens = tokenizer.encode(chat_text, **encode_kwargs)
 
     slices: list[TokenSlice] = []
@@ -334,7 +335,13 @@
     for message in chat.all:
         # Find this message
         if not (
-            match := find_in_tokens(message.content, chat_tokens, tokenizer.decode, 0, search_start)
+            match := find_in_tokens(
+                message.content,
+                chat_tokens,
+                lambda tokens: tokenizer.decode(tokens),
+                0,
+                search_start,
+            )
         ):
             warnings.warn(
                 f"Warning: Could not find message '{message.content[:50]}...' in chat tokens",
@@ -370,7 +377,7 @@ async def chats_to_tokens(
             part_match = find_in_tokens(
                 part_text,
                 message_tokens,
-                tokenizer.decode,
+                lambda tokens: tokenizer.decode(tokens),
                 msg_start,
                 part_search_start,
             )
@@ -399,8 +406,9 @@ async def chats_to_tokens(
         # Continue searching after this message
         search_start = msg_end
 
+    # we ask for a string by default in apply_chat_template_kwargs with the tokenize=False
     return TokenizedChat(
-        text=chat_text,
+        text=chat_text,  # type: ignore[arg-type]
        tokens=chat_tokens,
        slices=slices,
        obj=chat,
````
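
Aside from the annotation change, note the `tokenizer.decode` to `lambda tokens: tokenizer.decode(tokens)` swap: `decode` on a properly typed tokenizer carries optional keyword arguments and accepts more than plain `list[int]`, which can fail to match a narrower decoder callback type, and the lambda pins the signature down. A rough sketch of that shape, where `Decoder` is an illustrative alias and not necessarily rigging's exact definition:

```python
import typing as t

# Illustrative alias: a callback that turns token ids back into text and
# takes nothing else; rigging.tokenize.base.Decoder may differ in detail.
Decoder = t.Callable[[list[int]], str]


def roundtrip(tokens: list[int], decode: Decoder) -> str:
    return decode(tokens)


# With a real tokenizer in scope, the wrapped form matches Decoder exactly:
#   roundtrip(ids, lambda tokens: tokenizer.decode(tokens))
```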

docs/api/tokenize.mdx

Lines changed: 9 additions & 7 deletions
````diff
@@ -12,7 +12,7 @@ get_tokenizer
 ```python
 get_tokenizer(
     tokenizer_id: str, **tokenizer_kwargs: Any
-) -> AutoTokenizer | None
+) -> t.Any
 ```
 
 Get the tokenizer from transformers model identifier, or from an already loaded tokenizer.
@@ -30,15 +30,15 @@ Get the tokenizer from transformers model identifier, or from an already loaded
 
 **Returns:**
 
-* `AutoTokenizer | None`
+* `Any`
   –An instance of `AutoTokenizer`.
 
 <Accordion title="Source code in rigging/tokenize/tokenizer.py" icon="code">
 ```python
 def get_tokenizer(
     tokenizer_id: str,
     **tokenizer_kwargs: t.Any,
-) -> AutoTokenizer | None:
+) -> t.Any:
     """
     Get the tokenizer from transformers model identifier, or from an already loaded tokenizer.
 
@@ -49,18 +49,20 @@ def get_tokenizer(
     Returns:
         An instance of `AutoTokenizer`.
     """
-    tokenizer: AutoTokenizer | None = None
-
     try:
+        from transformers import AutoTokenizer
+
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_id,
             **tokenizer_kwargs,
         )
         logger.success(f"Loaded tokenizer for model '{tokenizer_id}'")
 
-    except Exception as e:  # noqa: BLE001
+    except Exception as e:
         # Catch all exceptions to handle any issues with loading the tokenizer
-        logger.error(f"Failed to load tokenizer for model '{tokenizer_id}': {e}")
+        raise RuntimeError(
+            f"Failed to load tokenizer for model '{tokenizer_id}': {e}",
+        ) from e
 
     return tokenizer
 ```
````
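
The behavioral change here is easy to miss among the type fixes: `get_tokenizer` used to log an error and fall through (returning `None`), while it now raises `RuntimeError` with the original exception chained. A usage sketch, assuming `rigging` and `transformers` are installed (the failing model id is a placeholder):

```python
from rigging.tokenize.tokenizer import get_tokenizer

# Success path: returns the loaded tokenizer (typed as Any after this change).
tokenizer = get_tokenizer("gpt2")

# Failure path: callers no longer need to None-check; they catch instead.
try:
    get_tokenizer("definitely/not-a-real-model")
except RuntimeError as e:
    print(f"tokenizer load failed: {e}")
```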

rigging/data.py

Lines changed: 17 additions & 7 deletions
````diff
@@ -12,14 +12,16 @@
 import pandas as pd
 from elastic_transport import ObjectApiResponse
 from mypy_boto3_s3 import S3Client
-from transformers import AutoTokenizer
 
 from rigging.chat import Chat
 from rigging.error import TokenizeWarning
 from rigging.message import Message
 from rigging.tokenize import find_in_tokens
 from rigging.tokenize.base import TokenizedChat, TokenSlice
 
+if t.TYPE_CHECKING:
+    from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
 
 def flatten_chats(chats: Chat | t.Sequence[Chat]) -> list[dict[t.Any, t.Any]]:
     """
@@ -294,8 +296,8 @@ async def chats_to_elastic(
 
 
 async def chats_to_tokens(
-    chat: Chat | None,
-    tokenizer: AutoTokenizer,
+    chat: Chat,
+    tokenizer: "PreTrainedTokenizerBase",
     *,
     apply_chat_template_kwargs: dict[str, t.Any] | None = None,
     encode_kwargs: dict[str, t.Any] | None = None,
@@ -331,8 +333,9 @@ async def chats_to_tokens(
         if chat.params and chat.params.tools
         else None
     )
+    # the tools above return dict[str, Any], but Transformers expects list[dict[Any, Any]]
 
-    chat_text = tokenizer.apply_chat_template(messages, tools=tools, **apply_chat_template_kwargs)
+    chat_text = tokenizer.apply_chat_template(messages, tools=tools, **apply_chat_template_kwargs)  # type: ignore[arg-type]
     chat_tokens = tokenizer.encode(chat_text, **encode_kwargs)
 
     slices: list[TokenSlice] = []
@@ -342,7 +345,13 @@
     for message in chat.all:
         # Find this message
         if not (
-            match := find_in_tokens(message.content, chat_tokens, tokenizer.decode, 0, search_start)
+            match := find_in_tokens(
+                message.content,
+                chat_tokens,
+                lambda tokens: tokenizer.decode(tokens),
+                0,
+                search_start,
+            )
         ):
             warnings.warn(
                 f"Warning: Could not find message '{message.content[:50]}...' in chat tokens",
@@ -378,7 +387,7 @@ async def chats_to_tokens(
             part_match = find_in_tokens(
                 part_text,
                 message_tokens,
-                tokenizer.decode,
+                lambda tokens: tokenizer.decode(tokens),
                 msg_start,
                 part_search_start,
             )
@@ -407,8 +416,9 @@ async def chats_to_tokens(
         # Continue searching after this message
         search_start = msg_end
 
+    # we ask for a string by default in apply_chat_template_kwargs with the tokenize=False
     return TokenizedChat(
-        text=chat_text,
+        text=chat_text,  # type: ignore[arg-type]
        tokens=chat_tokens,
        slices=slices,
        obj=chat,
````
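
The `if t.TYPE_CHECKING:` guard plus the quoted `"PreTrainedTokenizerBase"` annotation is the standard recipe for typing against an optional dependency. A self-contained sketch of the pattern (`describe` is illustrative only):

```python
import typing as t

if t.TYPE_CHECKING:
    # Evaluated only by type checkers, so transformers stays an optional
    # runtime dependency of this module.
    from transformers.tokenization_utils_base import PreTrainedTokenizerBase


def describe(tokenizer: "PreTrainedTokenizerBase") -> str:
    # The quoted annotation defers evaluation, so importing this module
    # succeeds even when transformers is not installed.
    return type(tokenizer).__name__
```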

rigging/tokenize/tokenizer.py

Lines changed: 7 additions & 11 deletions
````diff
@@ -1,20 +1,14 @@
-import importlib.util
 import typing as t
 
-if importlib.util.find_spec("transformers") is None:
-    raise ModuleNotFoundError("Please install the `transformers` package to use this module.")
-
-
 from loguru import logger
-from transformers import AutoTokenizer
 
 from rigging.tokenize.base import Decoder
 
 
 def get_tokenizer(
     tokenizer_id: str,
     **tokenizer_kwargs: t.Any,
-) -> AutoTokenizer | None:
+) -> t.Any:
     """
     Get the tokenizer from transformers model identifier, or from an already loaded tokenizer.
 
@@ -25,18 +19,20 @@ def get_tokenizer(
     Returns:
         An instance of `AutoTokenizer`.
     """
-    tokenizer: AutoTokenizer | None = None
-
     try:
+        from transformers import AutoTokenizer
+
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_id,
             **tokenizer_kwargs,
        )
        logger.success(f"Loaded tokenizer for model '{tokenizer_id}'")
 
-    except Exception as e:  # noqa: BLE001
+    except Exception as e:
         # Catch all exceptions to handle any issues with loading the tokenizer
-        logger.error(f"Failed to load tokenizer for model '{tokenizer_id}': {e}")
+        raise RuntimeError(
+            f"Failed to load tokenizer for model '{tokenizer_id}': {e}",
+        ) from e
 
     return tokenizer
````
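
Worth noting on the new `raise RuntimeError(...) from e`: the information the removed `logger.error` call carried is not lost, since `from e` chains the original exception as `__cause__` and both tracebacks surface. A small dependency-free sketch of the chaining behavior (names are illustrative):

```python
def load(name: str) -> str:
    try:
        # stand-in for a failed download or a missing `transformers` install
        raise ValueError(f"unknown model '{name}'")
    except Exception as e:
        # `from e` attaches the original error as __cause__, so callers see
        # the full chain instead of a bare RuntimeError
        raise RuntimeError(f"Failed to load tokenizer for model '{name}': {e}") from e


try:
    load("not-a-model")
except RuntimeError as e:
    assert isinstance(e.__cause__, ValueError)
    print(e, "| caused by:", repr(e.__cause__))
```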
