Merged

Changes from all commits
28 commits
2f0391e
Add base64 encoding for complex parameters in identifier strings
monoxgas May 23, 2025
9c60211
Associate parsed native tool calls as tool_call objects on the Messag…
monoxgas May 23, 2025
849e26d
Add clarity to tool_method docstring
monoxgas May 27, 2025
0d01aa2
Add pipeline transforms. Refactor native tool calling.
monoxgas May 28, 2025
9c58362
Start porting transformers
monoxgas May 30, 2025
537275a
Rework model parsing to auto-wrap internal primitive type fields with…
monoxgas Jun 3, 2025
199b262
More tool refactoring and finishing the transform migration.
monoxgas Jun 3, 2025
4be24e6
Merge branch 'fix/nested-xml-model-parsing' into feat/tokenizer-and-t…
monoxgas Jun 3, 2025
9503d78
Working dump
monoxgas Jun 3, 2025
1049609
Improve error handling in tool calls
monoxgas Jun 3, 2025
1933180
More error handling changes
monoxgas Jun 3, 2025
ba239a5
More working changes
monoxgas Jun 6, 2025
adc58fe
Merge remote-tracking branch 'origin/main' into feat/tokenizer-and-to…
monoxgas Jun 8, 2025
df8559f
Finalizing transforms interface for tools. Refactored message parsed …
monoxgas Jun 10, 2025
976d1e7
some linting and typing fixes
monoxgas Jun 10, 2025
47c9a7f
tokenize chats
Jun 15, 2025
78f47fb
clean yellow squiggle for empty dict as default arg
Jun 15, 2025
49c8b7d
update docs
Jun 15, 2025
9e82f35
update docs
Jun 15, 2025
3f62e19
typing fixes
Jun 15, 2025
27fb6d0
typing fixes #B008
Jun 15, 2025
2a4638f
default tokenizer None
Jun 15, 2025
07270d5
typing
Jun 15, 2025
0ed8ec3
return tokenized chat
Jun 15, 2025
4f7b8b5
clean up arguments. Test with tokenize.ipynb/gsm8k
Jun 15, 2025
04600e8
return AutoTokenizer or None
Jun 15, 2025
d6cb045
rigging logger -> loguru
Jun 15, 2025
d9026e1
get_tokenizer return None
Jun 16, 2025
518 changes: 379 additions & 139 deletions docs/api/chat.mdx

Large diffs are not rendered by default.

160 changes: 159 additions & 1 deletion docs/api/data.mdx
@@ -250,6 +250,164 @@ def chats_to_elastic_data(
```


</Accordion>

chats\_to\_tokens
-----------------

```python
chats_to_tokens(
chat: Chat | Sequence[Chat],
tokenizer: AutoTokenizer,
*,
apply_chat_template_kwargs: dict[str, Any]
| None = None,
encode_kwargs: dict[str, Any] | None = None,
decode_kwargs: dict[str, Any] | None = None,
) -> TokenizedChat
```

Transform a chat into a tokenized format with structured slices.

**Parameters:**

* **`chat`**
(`Chat | Sequence[Chat]`)
–The chat object to tokenize.
* **`tokenizer`**
(`AutoTokenizer`)
–The tokenizer to use for encoding and decoding.

**Returns:**

* `TokenizedChat`
–A TokenizedChat object containing the tokenized chat data.

<Accordion title="Source code in rigging/data.py" icon="code">
```python
async def chats_to_tokens(
chat: Chat | t.Sequence[Chat],
tokenizer: AutoTokenizer,
*,
apply_chat_template_kwargs: dict[str, t.Any] | None = None,
encode_kwargs: dict[str, t.Any] | None = None,
decode_kwargs: dict[str, t.Any] | None = None,
) -> TokenizedChat:
"""
Transform a chat into a tokenized format with structured slices.

Args:
chat: The chat object to tokenize.
tokenizer: The tokenizer to use for encoding and decoding.

Returns:
A TokenizedChat object containing the tokenized chat data.
"""

apply_chat_template_kwargs = {
"tokenize": False,
**(apply_chat_template_kwargs or {}),
}
encode_kwargs = {
**(encode_kwargs or {}),
}
decode_kwargs = {
"clean_up_tokenization_spaces": False,
**(decode_kwargs or {}),
}

messages = [m.to_openai(compatibility_flags={"content_as_str"}) for m in chat.all]

tools = (
[tool.model_dump() for tool in chat.params.tools]
if chat.params and chat.params.tools
else None
)

chat_text = tokenizer.apply_chat_template(messages, tools=tools, **apply_chat_template_kwargs)
chat_tokens = tokenizer.encode(chat_text, **encode_kwargs)

slices: list[TokenSlice] = []
search_start = 0

# Process messages in order
for message in chat.all:
# Find this message
if not (
match := find_in_tokens(message.content, chat_tokens, tokenizer.decode, 0, search_start)
):
warnings.warn(
f"Warning: Could not find message '{message.content[:50]}...' in chat tokens",
TokenizeWarning,
stacklevel=2,
)
continue

msg_start, msg_end = match
msg_metadata = message.metadata or {}
msg_metadata["role"] = message.role
if message.tool_call_id:
msg_metadata["tool_call_id"] = message.tool_call_id

# Add message slice
slices.append(
TokenSlice(
start=msg_start,
end=msg_end,
type="message",
obj=message,
metadata=msg_metadata,
),
)

# Find parts within this message
message_tokens = chat_tokens[msg_start:msg_end]
part_search_start = 0

# Process message slices in order
for slice_ in message.slices:
part_text = message.content[slice_.slice_]
part_match = find_in_tokens(
part_text,
message_tokens,
tokenizer.decode,
msg_start,
part_search_start,
)
if not part_match:
warnings.warn(
f"Warning: Could not find part '{part_text[:50]}...' in message tokens",
TokenizeWarning,
stacklevel=2,
)
continue

part_start, part_end = part_match
slices.append(
TokenSlice(
start=part_start,
end=part_end,
type=slice_.type,
obj=slice_.obj,
metadata=slice_.metadata,
),
)

# Continue searching after this part
part_search_start = part_end - msg_start

# Continue searching after this message
search_start = msg_end

return TokenizedChat(
text=chat_text,
tokens=chat_tokens,
slices=slices,
obj=chat,
)
```


</Accordion>
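
For orientation, here is a minimal usage sketch (not part of the generated reference). It assumes a Hugging Face `AutoTokenizer`, a completed `Chat` from a typical rigging pipeline, and that `chats_to_tokens` is importable from `rigging.data` as the accordion title suggests; the model names are placeholders.

```python
import asyncio

from transformers import AutoTokenizer

import rigging as rg
from rigging.data import chats_to_tokens  # import path assumed from the accordion title


async def main() -> None:
    # Illustrative pipeline call; any completed Chat works here.
    chat = await rg.get_generator("openai/gpt-4o-mini").chat("Say hello.").run()

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    tokenized = await chats_to_tokens(chat, tokenizer)

    print(len(tokenized.tokens))       # total tokens in the rendered chat template
    for s in tokenized.slices:         # slices are aligned to token offsets
        print(s.type, s.start, s.end)


asyncio.run(main())
```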

df\_to\_chats
@@ -371,7 +529,7 @@ def elastic_data_to_chats(
# here as we aren't bonded to the underlying rg.Model
# which was the original object. Skipping for now.
for msg in chat.all:
msg.parts = []
msg.slices = []

chats.append(chat)

32 changes: 32 additions & 0 deletions docs/api/error.mdx
@@ -126,6 +126,14 @@ step = step

The pipeline step which caused the depth error.

MessageWarning
--------------

Base class for all message warnings.

This is used to indicate that something unexpected happened during message processing,
but it is not critical enough to stop execution.

MessagesExhaustedMaxRoundsError
-------------------------------

@@ -174,6 +182,14 @@ def __init__(self, content: str):

</Accordion>

PipelineWarning
---------------

Base class for all pipeline warnings.

This is used to indicate that something unexpected happened during pipeline execution,
but it is not critical enough to stop execution.

ProcessingError
---------------

@@ -237,6 +253,14 @@ message = message

The message associated with the stop.

TokenizeWarning
---------------

Base class for all tokenization warnings.

This is used to indicate that something unexpected happened during the tokenization process,
but it is not critical enough to stop execution.
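
If a failed alignment should halt processing rather than emit a warning, the standard `warnings` filters apply. A small sketch, assuming this class is importable from `rigging.error`:

```python
import warnings

from rigging.error import TokenizeWarning  # import path assumed from this page

# Escalate tokenization alignment warnings into exceptions.
warnings.filterwarnings("error", category=TokenizeWarning)
```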

ToolDefinitionError
-------------------

@@ -255,6 +279,14 @@ def __init__(self, message: str):

</Accordion>

ToolWarning
-----------

Base class for all tool warnings.

This is used to indicate that something unexpected happened during tool execution,
but it is not critical enough to stop execution.

UnknownToolError
----------------

74 changes: 64 additions & 10 deletions docs/api/generator.mdx

@@ -136,15 +136,15 @@ The timeout for the API request.
### tool\_choice

```python
tool_choice: ApiToolChoice | None = None
tool_choice: ToolChoice | None = None
```

The tool choice to be used in the generation.

### tools

```python
tools: list[ApiToolDefinition] | None = None
tools: list[ToolDefinition] | None = None
```

The tools to be used in the generation.
@@ -165,6 +165,34 @@ top_p: float | None = None

The nucleus sampling probability.

### clone

```python
clone() -> GenerateParams
```

Create a copy of the current parameters instance.

**Returns:**

* `GenerateParams`
–A new instance of GenerateParams with the same values.

<Accordion title="Source code in rigging/generator/base.py" icon="code">
```python
def clone(self) -> "GenerateParams":
"""
Create a copy of the current parameters instance.

Returns:
A new instance of GenerateParams with the same values.
"""
return self.model_copy(deep=True)
```


</Accordion>
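
A brief usage sketch: `clone()` returns a deep copy, so mutating the clone leaves the original untouched. The import path follows the accordion title, and `temperature`/`max_tokens` are assumed to be among the parameter fields.

```python
from rigging.generator.base import GenerateParams  # path assumed from the accordion title

params = GenerateParams(temperature=0.7, max_tokens=256)
tuned = params.clone()
tuned.temperature = 0.2

assert params.temperature == 0.7  # deep copy: the original is unchanged
```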

### merge\_with

```python
@@ -1263,6 +1291,16 @@ def get_generator(identifier: str, *, params: GenerateParams | None = None) -> G
except Exception as e:
raise InvalidModelSpecifiedError(identifier) from e

# Decode any base64 values if present
def decode_value(value: str) -> t.Any:
if value.startswith("base64:"):
with contextlib.suppress(Exception):
decoded = base64.b64decode(value[7:])
return TypeAdapter(t.Any).validate_json(decoded)
return value

kwargs = {k: decode_value(v) for k, v in kwargs.items()}

# See if any of the kwargs would apply to the cls constructor directly
init_signature = inspect.signature(generator_cls)
init_kwargs: dict[str, t.Any] = {
@@ -1353,23 +1391,39 @@ def get_identifier(generator: Generator, params: GenerateParams | None = None) -
)
identifier = f"{provider}!{generator.model}"

extra_cls_args = generator.model_dump(
identifier_extra = generator.model_dump(
exclude_unset=True,
exclude={"model", "api_key", "params"},
)
if extra_cls_args:
identifier += f",{','.join([f'{k}={v}' for k, v in extra_cls_args.items()])}"

merged_params = generator.params.merge_with(params)
if merged_params.extra:
logger.debug("Extra parameters are not supported in identifiers.")
merged_params.extra = {}

params_dict = merged_params.to_dict()
if params_dict:
if "stop" in params_dict:
params_dict["stop"] = ";".join(params_dict["stop"])
identifier += f",{','.join([f'{k}={v}' for k, v in params_dict.items()])}"
identifier_extra.update(merged_params.to_dict())

# Small correction for stop sequences
if identifier_extra and "stop" in identifier_extra:
identifier_extra["stop"] = ";".join(identifier_extra["stop"])

# Encode any complex values
def encode_value(val: t.Any) -> t.Any:
if isinstance(val, str | int | float | bool):
return val

with contextlib.suppress(Exception):
serialized = TypeAdapter(t.Any).dump_json(val)
encoded = base64.b64encode(serialized).decode()
return f"base64:{encoded}"

return val

identifier_extra = {k: encode_value(v) for k, v in identifier_extra.items()}

# Append them to the identifier
if identifier_extra:
identifier += f",{','.join([f'{k}={v}' for k, v in identifier_extra.items()])}"

return identifier
```
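
The `encode_value` helper above mirrors the `decode_value` added to `get_generator`: non-scalar values are JSON-serialized with pydantic's `TypeAdapter`, base64-encoded, and prefixed with `base64:`, then recovered when the identifier is parsed back. A standalone sketch of that round trip, independent of rigging itself:

```python
import base64
import typing as t

from pydantic import TypeAdapter


def encode_value(val: t.Any) -> t.Any:
    # Scalars pass through untouched; anything else becomes a "base64:" payload.
    if isinstance(val, str | int | float | bool):
        return val
    serialized = TypeAdapter(t.Any).dump_json(val)
    return f"base64:{base64.b64encode(serialized).decode()}"


def decode_value(value: t.Any) -> t.Any:
    if isinstance(value, str) and value.startswith("base64:"):
        return TypeAdapter(t.Any).validate_json(base64.b64decode(value[7:]))
    return value


complex_param = {"stop": ["a", "b"], "nested": {"depth": 2}}
assert decode_value(encode_value(complex_param)) == complex_param
```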