Skip to content

Commit 75421ef

Browse files
committed
Bunch of tool mechanic updates and bug fixes.
1 parent 311dfe3 commit 75421ef

File tree

6 files changed

+374
-81
lines changed

rigging/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from rigging.message import ContentImageUrl, ContentText, Message, MessageDict, Messages
2929
from rigging.model import Model, attr, element, wrapped
3030
from rigging.prompt import Ctx, Prompt, prompt
31-
from rigging.tool import Tool, mcp, robopages, tool
31+
from rigging.tool import Tool, mcp, robopages, tool, tool_method
3232
from rigging.util import await_
3333
from rigging.version import VERSION
3434

@@ -66,6 +66,7 @@
6666
"error",
6767
"parsing",
6868
"tool",
69+
"tool_method",
6970
"logging",
7071
"await_",
7172
"interact",

rigging/chat.py

Lines changed: 132 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,14 @@
2020

2121
from elasticsearch import AsyncElasticsearch
2222
from loguru import logger
23-
from pydantic import BaseModel, ConfigDict, Field, PlainSerializer, ValidationError, computed_field
23+
from pydantic import (
24+
BaseModel,
25+
ConfigDict,
26+
Field,
27+
PlainSerializer,
28+
ValidationError,
29+
computed_field,
30+
)
2431

2532
from rigging.error import MaxDepthError, UnknownToolError
2633
from rigging.generator import GenerateParams, Generator, get_generator
@@ -30,6 +37,7 @@
3037
from rigging.tool.api import ApiToolCall, ApiToolChoice
3138
from rigging.tool.base import Tool, ToolMode
3239
from rigging.tool.native import (
40+
TOOL_CALLS_TAG,
3341
JsonInXmlToolCall,
3442
JsonInXmlToolDefinition,
3543
XmlToolCall,
@@ -101,10 +109,17 @@ class Chat(BaseModel):
101109
params: GenerateParams | None = Field(None, exclude=True, repr=False)
102110
"""Any additional generation params used for this chat."""
103111

104-
error: t.Annotated[
105-
BaseException,
106-
PlainSerializer(lambda x: str(x), return_type=str, when_used="json-unless-none"),
107-
] | None = Field(None, repr=False)
112+
error: (
113+
t.Annotated[
114+
BaseException,
115+
PlainSerializer(
116+
lambda x: str(x),
117+
return_type=str,
118+
when_used="json-unless-none",
119+
),
120+
]
121+
| None
122+
) = Field(None, repr=False)
108123
"""Holds any exception that was caught during the generation pipeline."""
109124
failed: bool = Field(False, exclude=False, repr=True)
110125
"""
@@ -199,7 +214,10 @@ def message_dicts(self) -> list[MessageDict]:
199214
The MessageDict list
200215
"""
201216
return [
202-
t.cast(MessageDict, m.model_dump(include={"role", "content_parts"}, exclude_none=True))
217+
t.cast(
218+
MessageDict,
219+
m.model_dump(include={"role", "content_parts"}, exclude_none=True),
220+
)
203221
for m in self.all
204222
]
205223

@@ -364,7 +382,11 @@ def inject_system_content(self, content: str) -> "Chat":
364382
self.messages[0].content += "\n\n" + content
365383
return self
366384

367-
def inject_tool_prompt(self, tools: t.Sequence[Tool[..., t.Any]], mode: ToolMode) -> "Chat":
385+
def inject_tool_prompt(
386+
self,
387+
tools: t.Sequence[Tool[..., t.Any]],
388+
mode: ToolMode,
389+
) -> "Chat":
368390
"""
369391
Injects a default tool use prompt into the system prompt.
370392
@@ -699,14 +721,19 @@ def __init__(
699721
self.tool_mode: ToolMode = "auto"
700722
self.api_tool_choice: ApiToolChoice | None = None
701723
self.inject_tool_prompt = True
724+
self.stop_on_tool_calls = True
702725
self.then_callbacks: list[tuple[ThenChatCallback, int]] = []
703726
self.map_callbacks: list[tuple[MapChatCallback, int]] = []
704727
self.watch_callbacks: list[WatchChatCallback] = watch_callbacks or []
705728

706729
def __len__(self) -> int:
707730
return len(self.chat)
708731

709-
def with_(self, params: GenerateParams | None = None, **kwargs: t.Any) -> "ChatPipeline":
732+
def with_(
733+
self,
734+
params: GenerateParams | None = None,
735+
**kwargs: t.Any,
736+
) -> "ChatPipeline":
710737
"""
711738
Assign specific generation parameter overloads for this chat.
712739
@@ -850,7 +877,12 @@ def fork(
850877
"""
851878
return self.clone().add(messages)
852879

853-
def clone(self, *, only_messages: bool = False, chat: Chat | None = None) -> "ChatPipeline":
880+
def clone(
881+
self,
882+
*,
883+
only_messages: bool = False,
884+
chat: Chat | None = None,
885+
) -> "ChatPipeline":
854886
"""
855887
Creates a clone of the current `ChatPipeline` instance.
856888
@@ -1036,7 +1068,10 @@ def apply_to_all(self, **kwargs: str) -> "ChatPipeline":
10361068
new.chat.apply_to_all(**kwargs)
10371069
return new
10381070

1039-
def cache(self, mode: CacheMode | None | t.Literal[False] = "latest") -> "ChatPipeline":
1071+
def cache(
1072+
self,
1073+
mode: CacheMode | None | t.Literal[False] = "latest",
1074+
) -> "ChatPipeline":
10401075
"""
10411076
Sets the caching mode for the pipeline.
10421077
@@ -1072,6 +1107,7 @@ def using(
10721107
mode: ToolMode | None = None,
10731108
choice: ApiToolChoice | None = None,
10741109
max_depth: int = DEFAULT_MAX_DEPTH,
1110+
stop_on_tool_calls: bool | None = None,
10751111
) -> "ChatPipeline":
10761112
"""
10771113
Adds a tool or a sequence of tools to participate in the generation process.
@@ -1085,6 +1121,7 @@ def using(
10851121
mode: The tool calling mode to use (e.g., "xml", "json-in-xml", "api").
10861122
choice: The API tool choice to use. This is only relevant when using the "api" tool mode.
10871123
max_depth: The maximum depth for recursive tool calls (this is shared between all tools).
1124+
stop_on_tool_calls: When using natively parsed tools, whether to stop generation when a tool call block is observed.
10881125
10891126
Returns:
10901127
The updated pipeline.
@@ -1115,7 +1152,9 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"])
11151152
existing_names = {tool.name for tool in self.tools}
11161153
for tool in new_tools:
11171154
if tool.name in existing_names:
1118-
raise ValueError(f"Tool with name '{tool.name}' already exists in the pipeline.")
1155+
raise ValueError(
1156+
f"Tool with name '{tool.name}' already exists in the pipeline.",
1157+
)
11191158

11201159
self.tools += new_tools
11211160

@@ -1124,14 +1163,20 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"])
11241163
for callback, max_depth in self.then_callbacks
11251164
if callback != self._then_tools
11261165
]
1127-
self.then_callbacks.insert(0, (self._then_tools, max_depth)) # make sure this is first
1166+
self.then_callbacks.insert(
1167+
0,
1168+
(self._then_tools, max_depth),
1169+
) # make sure this is first
11281170

11291171
if mode is not None:
11301172
self.tool_mode = mode
11311173

11321174
if choice is not None:
11331175
self.api_tool_choice = choice
11341176

1177+
if stop_on_tool_calls is not None:
1178+
self.stop_on_tool_calls = stop_on_tool_calls
1179+
11351180
return self
11361181

11371182
def until_parsed_as(
@@ -1193,7 +1238,30 @@ def until_parsed_as(
11931238
# Internal callbacks for handling tools and parsing
11941239

11951240
async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None:
1196-
tool_calls: list[ApiToolCall] | list[XmlToolCall] | list[JsonInXmlToolCall] | None = None
1241+
if (
1242+
self.stop_on_tool_calls
1243+
and self.tool_mode in ["xml", "json-in-xml"]
1244+
and chat.stop_reason == "stop"
1245+
):
1246+
# If we:
1247+
# 1. Are using native tools
1248+
# 2. Set a stop token for the tool calls
1249+
# 3. Hit that stop token
1250+
#
1251+
# Then we should re-inject the closing tag for completeness.
1252+
1253+
for part in chat.last.content_parts:
1254+
if (
1255+
part.type == "text"
1256+
and f"<{TOOL_CALLS_TAG}>" in part.text
1257+
and f"</{TOOL_CALLS_TAG}>" not in part.text
1258+
):
1259+
part.text += f"</{TOOL_CALLS_TAG}>"
1260+
break
1261+
1262+
# Parse the actual tool calls
1263+
1264+
tool_calls: (list[ApiToolCall] | list[XmlToolCall] | list[JsonInXmlToolCall] | None) = None
11971265
if self.tool_mode == "api":
11981266
tool_calls = chat.last.tool_calls
11991267
if self.tool_mode == "xml":
@@ -1265,9 +1333,15 @@ async def _pre_run(self) -> None:
12651333
if self.tool_mode == "auto" and self.tools:
12661334
self.tool_mode = "api" if await self.generator.supports_function_calling() else "xml"
12671335

1268-
if self.tools and self.tool_mode in ["xml", "json-in-xml"] and self.inject_tool_prompt:
1269-
self.chat.inject_tool_prompt(self.tools, self.tool_mode)
1270-
self.inject_native_tool_prompt = False
1336+
if self.tools and self.tool_mode in ["xml", "json-in-xml"]:
1337+
if self.inject_tool_prompt:
1338+
self.chat.inject_tool_prompt(self.tools, self.tool_mode)
1339+
self.inject_native_tool_prompt = False
1340+
1341+
if self.stop_on_tool_calls:
1342+
self.params = self.params = GenerateParams()
1343+
self.params.stop = self.params.stop or []
1344+
self.params.stop.append(f"</{TOOL_CALLS_TAG}>")
12711345

12721346
if self.tools and self.tool_mode == "api":
12731347
if self.params is None:
@@ -1287,17 +1361,25 @@ def _fit_params(
12871361
params = [self.params.merge_with(p) for p in params]
12881362
return [(p or GenerateParams()) for p in params]
12891363

1290-
def _apply_cache_mode_to_messages(self, messages: list[list[Message]]) -> list[list[Message]]:
1364+
def _apply_cache_mode_to_messages(
1365+
self,
1366+
messages: list[list[Message]],
1367+
) -> list[list[Message]]:
12911368
if self.caching is None:
12921369
return messages
12931370

12941371
if self.caching != "latest":
1295-
logger.warning(f"Unknown caching mode '{self.caching}', defaulting to 'latest'")
1372+
logger.warning(
1373+
f"Unknown caching mode '{self.caching}', defaulting to 'latest'",
1374+
)
12961375

12971376
# first remove existing cache settings
12981377
updated: list[list[Message]] = []
12991378
for _messages in messages:
1300-
updated = [*updated, [m.clone().cache(cache_control=False) for m in _messages]]
1379+
updated = [
1380+
*updated,
1381+
[m.clone().cache(cache_control=False) for m in _messages],
1382+
]
13011383

13021384
# then apply the latest cache settings
13031385
for _messages in updated:
@@ -1463,15 +1545,6 @@ async def _step( # noqa: PLR0915, PLR0912
14631545

14641546
# Check if we should immediately raise
14651547

1466-
# FailMode = t.Literal["raise", "skip", "include"]
1467-
# self.on_failed: FailMode = "raise"
1468-
# """How to handle failures in the pipeline unless overriden in calls."""
1469-
1470-
# self.errors_to_catch: set[type[Exception]] = {MaxDepthError, ValidationError}
1471-
# """The list of exceptions to catch during generation if you are including or skipping failures."""
1472-
# self.errors_to_exclude: set[type[Exception]] = set()
1473-
# """The list of exceptions to exclude from the catch list."""
1474-
14751548
for chat in chats:
14761549
if chat.error is not None and (
14771550
on_failed == "raise"
@@ -1532,7 +1605,11 @@ async def _step( # noqa: PLR0915, PLR0912
15321605
step = state.step.with_parent(current_step)
15331606

15341607
if step.depth > max_depth:
1535-
max_depth_error = MaxDepthError(max_depth, step, callback_name)
1608+
max_depth_error = MaxDepthError(
1609+
max_depth,
1610+
step,
1611+
callback_name,
1612+
)
15361613
if on_failed == "raise":
15371614
raise max_depth_error
15381615

@@ -1597,12 +1674,18 @@ async def _step( # noqa: PLR0915, PLR0912
15971674
if inspect.isasyncgen(chats_or_generator):
15981675
generator = t.cast(
15991676
PipelineStepGenerator,
1600-
await exit_stack.enter_async_context(aclosing(chats_or_generator)),
1677+
await exit_stack.enter_async_context(
1678+
aclosing(chats_or_generator),
1679+
),
16011680
)
16021681
async for step in generator:
16031682
_step = step.with_parent(current_step)
16041683
if _step.depth > max_depth:
1605-
max_depth_error = MaxDepthError(max_depth, _step, callback_name)
1684+
max_depth_error = MaxDepthError(
1685+
max_depth,
1686+
_step,
1687+
callback_name,
1688+
)
16061689
if on_failed == "raise":
16071690
raise max_depth_error
16081691

@@ -1676,10 +1759,17 @@ async def step(
16761759
generator_id=self.generator.to_identifier(),
16771760
params=self.params.to_dict() if self.params is not None else {},
16781761
) as span:
1679-
async with aclosing(self._step(span, messages, params, on_failed)) as generator:
1762+
async with aclosing(
1763+
self._step(span, messages, params, on_failed),
1764+
) as generator:
16801765
yield generator
16811766

1682-
async def run(self, *, on_failed: FailMode | None = None, allow_failed: bool = False) -> Chat:
1767+
async def run(
1768+
self,
1769+
*,
1770+
on_failed: FailMode | None = None,
1771+
allow_failed: bool = False,
1772+
) -> Chat:
16831773
"""
16841774
Execute the generation process for a single message.
16851775
@@ -1749,7 +1839,9 @@ async def step_many(
17491839
generator_id=self.generator.to_identifier(),
17501840
params=self.params.to_dict() if self.params is not None else {},
17511841
) as span:
1752-
async with aclosing(self._step(span, messages, params, on_failed)) as generator:
1842+
async with aclosing(
1843+
self._step(span, messages, params, on_failed),
1844+
) as generator:
17531845
yield generator
17541846

17551847
async def run_many(
@@ -1827,7 +1919,9 @@ async def step_batch(
18271919
messages = [[*self.chat.all, *Message.fit_as_list(m)] for m in many]
18281920
if len(messages) < count:
18291921
if len(messages) != 1:
1830-
raise ValueError(f"Can't fit {len(messages)} messages to {count} params")
1922+
raise ValueError(
1923+
f"Can't fit {len(messages)} messages to {count} params",
1924+
)
18311925
messages = messages * count
18321926

18331927
params = self._fit_params(count, params)
@@ -1838,7 +1932,9 @@ async def step_batch(
18381932
generator_id=self.generator.to_identifier(),
18391933
params=self.params.to_dict() if self.params is not None else {},
18401934
) as span:
1841-
async with aclosing(self._step(span, messages, params, on_failed)) as generator:
1935+
async with aclosing(
1936+
self._step(span, messages, params, on_failed),
1937+
) as generator:
18421938
yield generator
18431939

18441940
async def run_batch(

0 commit comments

Comments (0)