Skip to content

Commit 9089d25

Browse files
committed
Add pythonic tool transforms for llama and gemma. Version to 3.3.4.
1 parent 3c983ba commit 9089d25

File tree

13 files changed

+669
-72
lines changed

13 files changed

+669
-72
lines changed

docs/api/tools.mdx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ ToolMode = Literal[
2121
"json",
2222
"json-in-xml",
2323
"json-with-tag",
24+
"pythonic",
2425
]
2526
```
2627

@@ -32,6 +33,7 @@ How tool calls are handled.
3233
* `json`: Tool calls are parsed as raw name/arg JSON anywhere in assistant message content.
3334
* `json-in-xml`: Tool calls are parsed using JSON for arguments, and XML for everything else.
3435
* `json-with-tag`: Tool calls are parsed as name/arg JSON structures inside an XML tag to identify it.
36+
* `pythonic`: Tool calls are parsed as pythonic function call syntax.
3537

3638
Tool
3739
----

docs/api/transform.mdx

Lines changed: 272 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,19 @@ tool responses are converted to user messages with a "tool\_response" type.
5555

5656
See `make_tools_to_json_transform` for more details and more behavior options.
5757

58+
tools\_to\_pythonic\_transform
59+
------------------------------
60+
61+
```python
62+
tools_to_pythonic_transform = (
63+
make_tools_to_pythonic_transform()
64+
)
65+
```
66+
67+
A transform that converts tool calls to a pythonic list format.
68+
69+
See `make_tools_to_pythonic_transform` for more details and more behavior options.
70+
5871
PostTransform
5972
-------------
6073

@@ -159,6 +172,8 @@ def get_transform(identifier: str) -> Transform:
159172
return tools_to_json_in_xml_transform
160173
case "json-with-tag":
161174
return tools_to_json_with_tag_transform
175+
case "pythonic":
176+
return tools_to_pythonic_transform
162177
case _:
163178
raise ValueError(f"Unknown transform identifier: {identifier}")
164179
```
@@ -289,20 +304,31 @@ def make_tools_to_json_transform( # noqa: PLR0915
289304

290305
# Render all our existing tool calls as JSON in the content
291306

292-
for message in messages:
293-
if tool_responses_as_user_messages and message.role == "tool":
294-
message.replace_with_slice(
295-
tool_response_cls(
296-
id=message.tool_call_id or "",
297-
result=message.content,
298-
),
299-
"tool_response",
300-
metadata={"id": message.tool_call_id or ""},
301-
)
302-
message.role = "user"
303-
message.tool_call_id = None
307+
updated_messages: list[Message] = []
308+
309+
for is_tool_group, message_group in itertools.groupby(
310+
messages, key=lambda m: tool_responses_as_user_messages and m.role == "tool"
311+
):
312+
if is_tool_group:
313+
user_message = Message(role="user", content="")
314+
for message in message_group:
315+
user_message.append_slice(
316+
tool_response_cls(
317+
id=message.tool_call_id or "",
318+
result=message.content,
319+
),
320+
"tool_response",
321+
metadata={"id": message.tool_call_id or ""},
322+
)
323+
updated_messages.append(user_message)
324+
continue
304325

305-
elif message.tool_calls:
326+
for message in message_group:
327+
if not message.tool_calls:
328+
updated_messages.append(message)
329+
continue
330+
331+
updated_message = message.clone()
306332
for tool_call in message.tool_calls:
307333
content: str | Model
308334
match mode:
@@ -321,14 +347,15 @@ def make_tools_to_json_transform( # noqa: PLR0915
321347
content=f'{{"name": "{tool_call.function.name}", "arguments": "{tool_call.function.arguments}"}}',
322348
)
323349

324-
message.append_slice(
350+
updated_message.append_slice(
325351
content,
326352
"tool_call",
327353
obj=tool_call,
328354
metadata={"id": tool_call.id or ""},
329355
)
330356

331-
message.tool_calls = None
357+
updated_message.tool_calls = None
358+
updated_messages.append(updated_message)
332359

333360
# Save any existing tool params
334361

@@ -434,13 +461,22 @@ def make_tools_to_json_transform( # noqa: PLR0915
434461

435462
# Convert our tool responses
436463

437-
for message in [m for m in chat.all if m.role == "user"]:
438-
if (tool_response := message.try_parse(tool_response_cls)) is None:
464+
updated_messages = []
465+
for message in messages:
466+
if message.role != "user" or not (
467+
tool_responses := message.try_parse_set(tool_response_cls)
468+
):
469+
updated_messages.append(message)
439470
continue
440471

441-
message.content = tool_response.result
442-
message.tool_call_id = tool_response.id
443-
message.role = "tool"
472+
for tool_response in tool_responses:
473+
updated_messages.append( # noqa: PERF401
474+
Message(
475+
role="tool",
476+
content=tool_response.result,
477+
tool_call_id=tool_response.id,
478+
)
479+
)
444480

445481
# Restore the params
446482

@@ -450,16 +486,230 @@ def make_tools_to_json_transform( # noqa: PLR0915
450486

451487
# Strip the system prompt content
452488

453-
chat.messages = strip_system_content(chat.messages, system_prompt)
489+
chat.messages = strip_system_content(updated_messages, system_prompt)
454490

455491
return chat
456492

457-
return messages, params, json_to_tools_transform
493+
return updated_messages, params, json_to_tools_transform
458494

459495
return tools_to_json_transform
460496
```
461497

462498

499+
</Accordion>
500+
501+
make\_tools\_to\_pythonic\_transform
502+
------------------------------------
503+
504+
```python
505+
make_tools_to_pythonic_transform(
506+
*,
507+
system_tool_prompt: Callable[
508+
[list[ToolDefinition]], str
509+
]
510+
| str
511+
| None = None,
512+
tool_responses_as_user_messages: bool = True,
513+
tool_response_tag: str = "tool-response",
514+
) -> Transform
515+
```
516+
517+
Create a transform that converts tool calls to a pythonic list format.
518+
519+
This transform will:
520+
1. Inject a system prompt with tool definitions serialized as JSON.
521+
2. Convert existing tool calls in messages to `[my_func(arg=...)]` format.
522+
3. Convert tool result messages into `<tool-response>` blocks in a user message (optional).
523+
4. In the post-transform, parse the model's output using a robust,
524+
AST-based parser to extract tool calls from the generated string.
525+
526+
**Parameters:**
527+
528+
* **`system_tool_prompt`**
529+
(`Callable[[list[ToolDefinition]], str] | str | None`, default:
530+
`None`
531+
)
532+
–A callable or string that generates the system prompt for tools.
533+
* **`tool_responses_as_user_messages`**
534+
(`bool`, default:
535+
`True`
536+
)
537+
–If True, tool responses will be converted to user messages wrapped in tool response tags.
538+
* **`tool_response_tag`**
539+
(`str`, default:
540+
`'tool-response'`
541+
)
542+
–The tag to use for tool responses in user messages.
543+
544+
**Returns:**
545+
546+
* `Transform`
547+
–A transform function that processes messages and generate params.
548+
549+
<Accordion title="Source code in rigging/transform/pythonic_tools.py" icon="code">
550+
```python
551+
def make_tools_to_pythonic_transform(
552+
*,
553+
system_tool_prompt: t.Callable[[list[ToolDefinition]], str] | str | None = None,
554+
tool_responses_as_user_messages: bool = True,
555+
tool_response_tag: str = "tool-response",
556+
) -> Transform:
557+
"""
558+
Create a transform that converts tool calls to a pythonic list format.
559+
560+
This transform will:
561+
1. Inject a system prompt with tool definitions serialized as JSON.
562+
2. Convert existing tool calls in messages to `[my_func(arg=...)]` format.
563+
3. Convert tool result messages into `<tool-response>` blocks in a user message (optional).
564+
4. In the post-transform, parse the model's output using a robust,
565+
AST-based parser to extract tool calls from the generated string.
566+
567+
Args:
568+
system_tool_prompt: A callable or string that generates the system prompt for tools.
569+
tool_responses_as_user_messages: If True, tool responses will be converted to user messages wrapped in tool response tags.
570+
tool_response_tag: The tag to use for tool responses in user messages.
571+
572+
Returns:
573+
A transform function that processes messages and generate params.
574+
"""
575+
576+
system_tool_prompt = system_tool_prompt or pythonic_tools_prompt
577+
578+
tool_response_cls = pydantic_xml_create_model(
579+
"ToolResponse",
580+
__base__=ToolResponse,
581+
__cls_kwargs__={"tag": tool_response_tag},
582+
__tag__=tool_response_tag,
583+
)
584+
585+
async def tools_to_pythonic_transform(
586+
messages: list[Message],
587+
params: GenerateParams,
588+
) -> tuple[list[Message], GenerateParams, PostTransform | None]:
589+
# Inject tool definitions into the system prompt
590+
591+
system_prompt = (
592+
system_tool_prompt
593+
if isinstance(system_tool_prompt, str)
594+
else system_tool_prompt(params.tools or [])
595+
)
596+
messages = inject_system_content(messages, system_prompt)
597+
598+
# Render existing tool calls and responses
599+
600+
updated_messages: list[Message] = []
601+
602+
for is_tool_group, message_group in itertools.groupby(
603+
messages, key=lambda m: tool_responses_as_user_messages and m.role == "tool"
604+
):
605+
if is_tool_group:
606+
user_message = Message(role="user", content="")
607+
for tool_message in message_group:
608+
user_message.append_slice(
609+
tool_response_cls(
610+
id=tool_message.tool_call_id or "",
611+
result=tool_message.content,
612+
),
613+
"tool_response",
614+
metadata={"id": tool_message.tool_call_id or ""},
615+
)
616+
updated_messages.append(user_message)
617+
continue
618+
619+
for message in message_group:
620+
if not message.tool_calls:
621+
updated_messages.append(message)
622+
continue
623+
624+
updated_message = message.clone()
625+
rendered_calls = [
626+
_render_tool_call_to_pythonic_string(tc) for tc in message.tool_calls
627+
]
628+
updated_message.tool_calls = None
629+
updated_message.append_slice(
630+
f"[{', '.join(rendered_calls)}]",
631+
"tool_call",
632+
metadata={
633+
"id": message.tool_calls[0].id or ""
634+
}, # TODO(nick): Handle multiple tool call slices
635+
)
636+
updated_messages.append(updated_message)
637+
638+
# Save any existing tool params
639+
640+
existing_tool_definitions = params.tools
641+
params.tools = None
642+
existing_tool_choice = params.tool_choice
643+
params.tool_choice = None
644+
645+
# Build post transform
646+
647+
async def pythonic_to_tools_transform(chat: "Chat") -> "Chat":
648+
# Convert the tool calls and strip them
649+
650+
for message in [m for m in chat.all if m.role == "assistant"]:
651+
# Restore original tool calls - fast path for efficiency and consistency
652+
653+
for slice_ in message.slices:
654+
if slice_.type == "tool_call" and isinstance(slice_.obj, ToolCall):
655+
message.tool_calls = message.tool_calls or []
656+
message.tool_calls.append(slice_.obj)
657+
message.remove_slices(slice_)
658+
659+
# Otherwise, find any new tool calls in the content
660+
661+
candidates = _extract_bracketed_blocks(message.content)
662+
parsed_results: list[tuple[str, list[ToolCall]]] = []
663+
for candidate_str in candidates:
664+
if parsed_calls := _attempt_parse_tool_calls_from_string(candidate_str):
665+
parsed_results.append((candidate_str, parsed_calls)) # noqa: PERF401
666+
667+
if not parsed_results:
668+
continue
669+
670+
# NOTE(nick): We only take the last successfully parsed block
671+
tool_calls_str, tool_calls = parsed_results[-1]
672+
message.tool_calls = tool_calls
673+
message.remove_slices()
674+
message.content = message.content.replace(tool_calls_str, "").strip()
675+
676+
# Convert our tool responses
677+
678+
updated_messages = []
679+
for message in messages:
680+
if message.role != "user" or not (
681+
tool_responses := message.try_parse_set(tool_response_cls)
682+
):
683+
updated_messages.append(message)
684+
continue
685+
686+
for tool_response in tool_responses:
687+
updated_messages.append( # noqa: PERF401
688+
Message(
689+
role="tool",
690+
content=tool_response.result,
691+
tool_call_id=tool_response.id,
692+
)
693+
)
694+
695+
# Restore the params
696+
697+
chat.params = chat.params or GenerateParams()
698+
chat.params.tools = existing_tool_definitions
699+
chat.params.tool_choice = existing_tool_choice
700+
701+
# Strip the system prompt content
702+
703+
chat.messages = strip_system_content(updated_messages, system_prompt)
704+
705+
return chat
706+
707+
return updated_messages, params, pythonic_to_tools_transform
708+
709+
return tools_to_pythonic_transform
710+
```
711+
712+
463713
</Accordion>
464714

465715
make\_tools\_to\_xml\_transform

docs/topics/tools.mdx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ The `mode` parameter in `.using()` controls how Rigging interacts with the langu
219219
- **`xml`:** Rigging injects instructions and an XML schema into the prompt, telling the model how to format its output to request a tool call using specific XML tags. Rigging parses this XML.
220220
- **`json-in-xml`:** Similar to `xml`, but the model is instructed to place a JSON object containing the arguments within the XML tags.
221221
- **`json-with-tag`:** Similar to `json`, but the JSON object is wrapped in a specific tag (e.g., `<tool_call>`).
222-
- **`json`:** The model is instructed to output tool calls as raw JSON anywhere in the message content with `name` and `arguments` fields
222+
- **`json`:** The model is instructed to output tool calls as raw JSON anywhere in the message content with `name` and `arguments` fields.
223+
- **`pythonic`:** The model is instructed to output tool calls as Python-style function calls in the message content.
223224

224225
Generally, `auto` is recommended as it leverages the most efficient method available.
225226

0 commit comments

Comments
 (0)