Commit 09a415d

Merge pull request #129 from dreadnode/feat/tokenizer-and-tool-tracking
feat: Tokenizer and tool tracking [WIP] unresolved: typing from Transformers, and get_tokenizer could return None, causing chats_to_tokens to fail, which is expected. It is not expected that chats_to_tokens returns None. test: tokenize.ipynb
2 parents e78374a + d9026e1 commit 09a415d


46 files changed (+5922 / -2576 lines)

docs/api/chat.mdx

Lines changed: 379 additions & 139 deletions
Large diffs are not rendered by default.

docs/api/data.mdx

Lines changed: 159 additions & 1 deletion
@@ -250,6 +250,164 @@ def chats_to_elastic_data(
```


</Accordion>

chats\_to\_tokens
-----------------

```python
chats_to_tokens(
    chat: Chat | Sequence[Chat],
    tokenizer: AutoTokenizer,
    *,
    apply_chat_template_kwargs: dict[str, Any] | None = None,
    encode_kwargs: dict[str, Any] | None = None,
    decode_kwargs: dict[str, Any] | None = None,
) -> TokenizedChat
```

Transform a chat into a tokenized format with structured slices.

**Parameters:**

* **`chat`** (`Chat | Sequence[Chat]`) – The chat object to tokenize.
* **`tokenizer`** (`AutoTokenizer`) – The tokenizer to use for encoding and decoding.

**Returns:**

* `TokenizedChat` – A TokenizedChat object containing the tokenized chat data.

<Accordion title="Source code in rigging/data.py" icon="code">
```python
async def chats_to_tokens(
    chat: Chat | t.Sequence[Chat],
    tokenizer: AutoTokenizer,
    *,
    apply_chat_template_kwargs: dict[str, t.Any] | None = None,
    encode_kwargs: dict[str, t.Any] | None = None,
    decode_kwargs: dict[str, t.Any] | None = None,
) -> TokenizedChat:
    """
    Transform a chat into a tokenized format with structured slices.

    Args:
        chat: The chat object to tokenize.
        tokenizer: The tokenizer to use for encoding and decoding.

    Returns:
        A TokenizedChat object containing the tokenized chat data.
    """
    apply_chat_template_kwargs = {
        "tokenize": False,
        **(apply_chat_template_kwargs or {}),
    }
    encode_kwargs = {
        **(encode_kwargs or {}),
    }
    decode_kwargs = {
        "clean_up_tokenization_spaces": False,
        **(decode_kwargs or {}),
    }

    messages = [m.to_openai(compatibility_flags={"content_as_str"}) for m in chat.all]

    tools = (
        [tool.model_dump() for tool in chat.params.tools]
        if chat.params and chat.params.tools
        else None
    )

    chat_text = tokenizer.apply_chat_template(messages, tools=tools, **apply_chat_template_kwargs)
    chat_tokens = tokenizer.encode(chat_text, **encode_kwargs)

    slices: list[TokenSlice] = []
    search_start = 0

    # Process messages in order
    for message in chat.all:
        # Find this message
        if not (
            match := find_in_tokens(message.content, chat_tokens, tokenizer.decode, 0, search_start)
        ):
            warnings.warn(
                f"Warning: Could not find message '{message.content[:50]}...' in chat tokens",
                TokenizeWarning,
                stacklevel=2,
            )
            continue

        msg_start, msg_end = match
        msg_metadata = message.metadata or {}
        msg_metadata["role"] = message.role
        if message.tool_call_id:
            msg_metadata["tool_call_id"] = message.tool_call_id

        # Add message slice
        slices.append(
            TokenSlice(
                start=msg_start,
                end=msg_end,
                type="message",
                obj=message,
                metadata=msg_metadata,
            ),
        )

        # Find parts within this message
        message_tokens = chat_tokens[msg_start:msg_end]
        part_search_start = 0

        # Process message slices in order
        for slice_ in message.slices:
            part_text = message.content[slice_.slice_]
            part_match = find_in_tokens(
                part_text,
                message_tokens,
                tokenizer.decode,
                msg_start,
                part_search_start,
            )
            if not part_match:
                warnings.warn(
                    f"Warning: Could not find part '{part_text[:50]}...' in message tokens",
                    TokenizeWarning,
                    stacklevel=2,
                )
                continue

            part_start, part_end = part_match
            slices.append(
                TokenSlice(
                    start=part_start,
                    end=part_end,
                    type=slice_.type,
                    obj=slice_.obj,
                    metadata=slice_.metadata,
                ),
            )

            # Continue searching after this part
            part_search_start = part_end - msg_start

        # Continue searching after this message
        search_start = msg_end

    return TokenizedChat(
        text=chat_text,
        tokens=chat_tokens,
        slices=slices,
        obj=chat,
    )
```

</Accordion>
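
For orientation, a minimal usage sketch of the new helper, assuming `chats_to_tokens` is importable from `rigging.data`, a Hugging Face `AutoTokenizer`, and a `Chat` produced by a prior rigging generation; the model names and prompt below are illustrative and not part of this commit:

```python
import asyncio

from transformers import AutoTokenizer

import rigging as rg
from rigging.data import chats_to_tokens  # assumed import path for this module


async def main() -> None:
    # Illustrative tokenizer - any tokenizer exposing a chat template should work
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    # Illustrative chat - any existing rigging Chat object can be tokenized
    chat = await rg.get_generator("gpt-4o-mini").chat("Say hello in one sentence.").run()

    tokenized = await chats_to_tokens(chat, tokenizer)

    print(tokenized.text[:200])      # rendered chat-template text
    print(len(tokenized.tokens))     # token ids for the whole chat
    for slice_ in tokenized.slices:  # structured message / part slices
        print(slice_.type, slice_.start, slice_.end)


asyncio.run(main())
```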

df\_to\_chats
@@ -371,7 +529,7 @@ def elastic_data_to_chats(
        # here as we aren't bonded to the underlying rg.Model
        # which was the original object. Skipping for now.
        for msg in chat.all:
-           msg.parts = []
+           msg.slices = []

        chats.append(chat)

docs/api/error.mdx

Lines changed: 32 additions & 0 deletions
@@ -126,6 +126,14 @@ step = step

The pipeline step which caused the depth error.

MessageWarning
--------------

Base class for all message warnings.

This is used to indicate that something unexpected happened during message processing,
but it is not critical enough to stop the execution.

MessagesExhaustedMaxRoundsError
-------------------------------

@@ -174,6 +182,14 @@ def __init__(self, content: str):

</Accordion>

PipelineWarning
---------------

Base class for all pipeline warnings.

This is used to indicate that something unexpected happened during pipeline execution,
but it is not critical enough to stop the execution.

ProcessingError
---------------

@@ -237,6 +253,14 @@ message = message

The message associated with the stop.

TokenizeWarning
---------------

Base class for all tokenization warnings.

This is used to indicate that something unexpected happened during the tokenization process,
but it is not critical enough to stop the execution.

ToolDefinitionError
-------------------

@@ -255,6 +279,14 @@ def __init__(self, message: str):

</Accordion>

ToolWarning
-----------

Base class for all tool warnings.

This is used to indicate that something unexpected happened during tool execution,
but it is not critical enough to stop the execution.

UnknownToolError
----------------
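
The new warning classes are ordinary Python warnings, so callers can silence or escalate them with the standard `warnings` machinery. A small sketch, assuming the classes are importable from `rigging.error` (the module backing this page):

```python
import warnings

from rigging.error import TokenizeWarning, ToolWarning  # assumed import path

# Fail fast while debugging tokenization: turn mismatch warnings into errors
warnings.filterwarnings("error", category=TokenizeWarning)

# Or keep long-running jobs quiet by ignoring non-critical tool warnings
warnings.filterwarnings("ignore", category=ToolWarning)
```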

docs/api/generator.mdx

Lines changed: 64 additions & 10 deletions
@@ -136,15 +136,15 @@ The timeout for the API request.
### tool\_choice

```python
- tool_choice: ApiToolChoice | None = None
+ tool_choice: ToolChoice | None = None
```

The tool choice to be used in the generation.

### tools

```python
- tools: list[ApiToolDefinition] | None = None
+ tools: list[ToolDefinition] | None = None
```

The tools to be used in the generation.
@@ -165,6 +165,34 @@ top_p: float | None = None

The nucleus sampling probability.

### clone

```python
clone() -> GenerateParams
```

Create a copy of the current parameters instance.

**Returns:**

* `GenerateParams` – A new instance of GenerateParams with the same values.

<Accordion title="Source code in rigging/generator/base.py" icon="code">
```python
def clone(self) -> "GenerateParams":
    """
    Create a copy of the current parameters instance.

    Returns:
        A new instance of GenerateParams with the same values.
    """
    return self.model_copy(deep=True)
```

</Accordion>
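
A short sketch of the intended use: `clone()` deep-copies the parameters so one instance can be tweaked without touching the original. The field values below are illustrative, and `GenerateParams` is assumed to be exported from the top-level `rigging` package:

```python
import rigging as rg

base = rg.GenerateParams(temperature=0.7, max_tokens=256)

variant = base.clone()
variant.temperature = 0.1  # deep copy, so the original instance is unaffected

assert base.temperature == 0.7
```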

### merge\_with

@@ -1263,6 +1291,16 @@ def get_generator(identifier: str, *, params: GenerateParams | None = None) -> G
    except Exception as e:
        raise InvalidModelSpecifiedError(identifier) from e

+   # Decode any base64 values if present
+   def decode_value(value: str) -> t.Any:
+       if value.startswith("base64:"):
+           with contextlib.suppress(Exception):
+               decoded = base64.b64decode(value[7:])
+               return TypeAdapter(t.Any).validate_json(decoded)
+       return value
+
+   kwargs = {k: decode_value(v) for k, v in kwargs.items()}
+
    # See if any of the kwargs would apply to the cls constructor directly
    init_signature = inspect.signature(generator_cls)
    init_kwargs: dict[str, t.Any] = {
@@ -1353,23 +1391,39 @@ def get_identifier(generator: Generator, params: GenerateParams | None = None) -
    )
    identifier = f"{provider}!{generator.model}"

-   extra_cls_args = generator.model_dump(
+   identifier_extra = generator.model_dump(
        exclude_unset=True,
        exclude={"model", "api_key", "params"},
    )
-   if extra_cls_args:
-       identifier += f",{','.join([f'{k}={v}' for k, v in extra_cls_args.items()])}"

    merged_params = generator.params.merge_with(params)
    if merged_params.extra:
        logger.debug("Extra parameters are not supported in identifiers.")
        merged_params.extra = {}

-   params_dict = merged_params.to_dict()
-   if params_dict:
-       if "stop" in params_dict:
-           params_dict["stop"] = ";".join(params_dict["stop"])
-       identifier += f",{','.join([f'{k}={v}' for k, v in params_dict.items()])}"
+   identifier_extra.update(merged_params.to_dict())
+
+   # Small correction for stop sequences
+   if identifier_extra and "stop" in identifier_extra:
+       identifier_extra["stop"] = ";".join(identifier_extra["stop"])
+
+   # Encode any complex values
+   def encode_value(val: t.Any) -> t.Any:
+       if isinstance(val, str | int | float | bool):
+           return val
+
+       with contextlib.suppress(Exception):
+           serialized = TypeAdapter(t.Any).dump_json(val)
+           encoded = base64.b64encode(serialized).decode()
+           return f"base64:{encoded}"
+
+       return val
+
+   identifier_extra = {k: encode_value(v) for k, v in identifier_extra.items()}
+
+   # Append them to the identifier
+   if identifier_extra:
+       identifier += f",{','.join([f'{k}={v}' for k, v in identifier_extra.items()])}"

    return identifier
```
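
The `encode_value` / `decode_value` pair above round-trips non-scalar values through JSON and base64 so they survive the flat `key=value` identifier syntax. A standalone sketch of that round-trip, using pydantic's `TypeAdapter` as the diff does; the sample value is illustrative:

```python
import base64
import contextlib
import typing as t

from pydantic import TypeAdapter


def encode_value(val: t.Any) -> t.Any:
    # Scalars can sit directly in the identifier string
    if isinstance(val, str | int | float | bool):
        return val
    # Anything else: JSON-serialize, then base64-encode behind a prefix
    with contextlib.suppress(Exception):
        serialized = TypeAdapter(t.Any).dump_json(val)
        return f"base64:{base64.b64encode(serialized).decode()}"
    return val


def decode_value(value: str) -> t.Any:
    # Reverse the encoding when parsing an identifier back into kwargs
    if value.startswith("base64:"):
        with contextlib.suppress(Exception):
            return TypeAdapter(t.Any).validate_json(base64.b64decode(value[7:]))
    return value


extra_headers = {"X-Org": "demo"}  # illustrative complex generator kwarg
encoded = encode_value(extra_headers)
assert isinstance(encoded, str) and encoded.startswith("base64:")
assert decode_value(encoded) == extra_headers
```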
