Skip to content

Commit 95cc796

Browse files
authored
fix: Tool Calling w/ Parsing (#108)
* Added a new `Stop` exception for breaking out of recursive tool calls.
* Fixed formatting.
1 parent 2cbecd8 commit 95cc796

File tree

6 files changed

+101
-50
lines changed

6 files changed

+101
-50
lines changed

rigging/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
MapCompletionCallback,
1515
ThenCompletionCallback,
1616
)
17+
from rigging.error import Stop
1718
from rigging.generator import (
1819
GeneratedMessage,
1920
GeneratedText,
@@ -64,6 +65,7 @@
6465
"PipelineStepContextManager",
6566
"PipelineStepGenerator",
6667
"Prompt",
68+
"Stop",
6769
"ThenChatCallback",
6870
"ThenCompletionCallback",
6971
"Tool",

rigging/chat.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ def __init__(
720720
self.tool_mode: ToolMode = "auto"
721721
self.api_tool_choice: ApiToolChoice | None = None
722722
self.inject_tool_prompt = True
723-
self.stop_on_tool_calls = True
723+
self.add_tool_stop_token = True
724724
self.then_callbacks: list[tuple[ThenChatCallback, int]] = []
725725
self.map_callbacks: list[tuple[MapChatCallback, int]] = []
726726
self.watch_callbacks: list[WatchChatCallback] = watch_callbacks or []
@@ -880,6 +880,7 @@ def clone(
880880
*,
881881
only_messages: bool = False,
882882
chat: Chat | None = None,
883+
callbacks: bool | t.Sequence[MapChatCallback | ThenChatCallback] = True,
883884
) -> "ChatPipeline":
884885
"""
885886
Creates a clone of the current `ChatPipeline` instance.
@@ -890,6 +891,8 @@ def clone(
890891
including until callbacks, types, tools, metadata, etc.
891892
chat: An optional chat object clone for use in the new pipeline, otherwise the current
892893
internal chat object will be cloned.
894+
callbacks: If True (default), all callbacks will be cloned. If False, no callbacks will be cloned.
895+
Otherwise provide a sequence of callbacks which should be maintained in the new pipeline.
893896
894897
Returns:
895898
The cloned ChatPipeline.
@@ -906,16 +909,20 @@ def clone(
906909
new.tools = self.tools.copy()
907910
new.tool_mode = self.tool_mode
908911
new.metadata = deepcopy(self.metadata)
909-
new.map_callbacks = self.map_callbacks.copy()
910912
new.on_failed = self.on_failed
911913
new.errors_to_catch = self.errors_to_catch.copy()
912914
new.errors_to_exclude = self.errors_to_exclude.copy()
913915
new.caching = self.caching
914916

917+
new.watch_callbacks = self.watch_callbacks.copy()
918+
915919
# Check if any of our callbacks are bound methods to a ChatPipline.
916920
# If so, we should rebind them to `self` to ensure they work correctly
917921
# and aren't operating with old state.
918922

923+
if callbacks is False:
924+
return new
925+
919926
new.then_callbacks = [
920927
(callback, max_depth)
921928
if not hasattr(callback, "__self__")
@@ -931,6 +938,18 @@ def clone(
931938
for callback, max_depth in self.map_callbacks.copy()
932939
]
933940

941+
if not isinstance(callbacks, bool):
942+
new.then_callbacks = [
943+
(callback, max_depth)
944+
for callback, max_depth in self.then_callbacks
945+
if callback in callbacks
946+
]
947+
new.map_callbacks = [
948+
(callback, max_depth)
949+
for callback, max_depth in self.map_callbacks
950+
if callback in callbacks
951+
]
952+
934953
return new
935954

936955
def meta(self, **kwargs: t.Any) -> "ChatPipeline":
@@ -1105,7 +1124,7 @@ def using(
11051124
mode: ToolMode | None = None,
11061125
choice: ApiToolChoice | None = None,
11071126
max_depth: int = DEFAULT_MAX_DEPTH,
1108-
stop_on_tool_calls: bool | None = None,
1127+
add_stop_token: bool | None = None,
11091128
) -> "ChatPipeline":
11101129
"""
11111130
Adds a tool or a sequence of tools to participate in the generation process.
@@ -1119,7 +1138,8 @@ def using(
11191138
mode: The tool calling mode to use (e.g., "xml", "json-in-xml", "api").
11201139
choice: The API tool choice to use. This is only relevant when using the "api" tool mode.
11211140
max_depth: The maximum depth for recursive tool calls (this is shared between all tools).
1122-
stop_on_tool_calls: When using natively parsed tools, whether to stop generation when a tool call block is observed.
1141+
add_stop_token: When using natively parsed tools ("xml", "json-in-xml"), use stop tokens to
1142+
immediately process a tool call when observed.
11231143
11241144
Returns:
11251145
The updated pipeline.
@@ -1172,8 +1192,8 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"])
11721192
if choice is not None:
11731193
self.api_tool_choice = choice
11741194

1175-
if stop_on_tool_calls is not None:
1176-
self.stop_on_tool_calls = stop_on_tool_calls
1195+
if add_stop_token is not None:
1196+
self.add_tool_stop_token = add_stop_token
11771197

11781198
return self
11791199

@@ -1237,7 +1257,7 @@ def until_parsed_as(
12371257

12381258
async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None:
12391259
if (
1240-
self.stop_on_tool_calls
1260+
self.add_tool_stop_token
12411261
and self.tool_mode in ["xml", "json-in-xml"]
12421262
and chat.stop_reason == "stop"
12431263
):
@@ -1270,54 +1290,49 @@ async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None:
12701290
if not tool_calls:
12711291
return None
12721292

1273-
next_pipeline = self.clone(chat=chat)
1293+
next_pipeline = self.clone(chat=chat, callbacks=[self._then_tools])
12741294

1275-
should_continue = True
1295+
stop = False
12761296

12771297
for tool_call in tool_calls:
12781298
tool = next((t for t in self.tools if t.name == tool_call.name), None)
12791299
if tool is None:
12801300
raise UnknownToolError(tool_call.name)
12811301

1282-
message, _should_continue = await tool.handle_tool_call(tool_call)
1302+
message, _stop = await tool.handle_tool_call(tool_call)
1303+
stop = _stop if not _stop else stop
12831304
next_pipeline.add(message)
12841305

1285-
# If the tool returns none, we should resolve tool calls, but
1286-
# not continue the pipeline.
1287-
1288-
if not _should_continue:
1289-
should_continue = _should_continue
1290-
12911306
# Need to prevent infinite loops and treat tool_choice like
12921307
# an ephemeral setting which resets after the first tool call.
12931308

12941309
if self.tool_mode == "api" and next_pipeline.params:
12951310
next_pipeline.params.tool_choice = None
12961311

1297-
if not should_continue:
1312+
if stop:
12981313
# TODO(nick): Type hints here stop us from mixing step generators
12991314
# and basic chat returns.
13001315
return next_pipeline.chat # type: ignore [return-value]
13011316

13021317
return next_pipeline.step()
13031318

13041319
async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None:
1305-
next_pipeline = self.clone(chat=chat)
1320+
next_pipeline = self.clone(chat=chat, callbacks=[self._then_parse])
13061321

13071322
try:
13081323
chat.last.parse_many(*self.until_types)
13091324
except ValidationError as e:
13101325
next_pipeline.add(
13111326
Message.from_model(
13121327
ValidationErrorModel(content=str(e)),
1313-
suffix="Rewrite your entire message with all the required xml structure.",
1328+
suffix="Rewrite your entire message with all of the required xml elements.",
13141329
),
13151330
)
13161331
except Exception as e: # noqa: BLE001
13171332
next_pipeline.add(
13181333
Message.from_model(
13191334
SystemErrorModel(content=str(e)),
1320-
suffix="Rewrite your entire message with all the required xml structure.",
1335+
suffix="Rewrite your entire message with all of the required xml elements.",
13211336
),
13221337
)
13231338
else: # parsed successfully
@@ -1336,7 +1351,7 @@ async def _pre_run(self) -> None:
13361351
self.chat.inject_tool_prompt(self.tools, self.tool_mode)
13371352
self.inject_native_tool_prompt = False
13381353

1339-
if self.stop_on_tool_calls:
1354+
if self.add_tool_stop_token:
13401355
self.params = self.params = GenerateParams()
13411356
self.params.stop = self.params.stop or []
13421357
self.params.stop.append(f"</{TOOL_CALLS_TAG}>")

rigging/error.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,38 @@
1414
from rigging.message import Message
1515

1616

17+
# User Throwable Exceptions
18+
19+
20+
class Stop(Exception): # noqa: N818
21+
"""
22+
Raise inside a pipeline to indicate a stopping condition.
23+
24+
Example:
25+
```
26+
import rigging as rg
27+
28+
async def read_file(path: str) -> str:
29+
"Read the contents of a file."
30+
31+
if no_more_files(path):
32+
raise rg.Stop("There are no more files to read.")
33+
34+
...
35+
36+
chat = await pipeline.using(read_file).run()
37+
```
38+
"""
39+
40+
def __init__(self, message: str):
41+
super().__init__(message)
42+
self.message = message
43+
"""The message associated with the stop."""
44+
45+
46+
# System Exceptions
47+
48+
1749
class UnknownToolError(Exception):
1850
"""
1951
Raised when an API tool call is made for an unknown tool.

rigging/prompt.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -566,14 +566,14 @@ async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None:
566566
next_pipeline.add(
567567
Message.from_model(
568568
ValidationErrorModel(content=str(e)),
569-
suffix="Rewrite your entire message with all the required xml structure.",
569+
suffix="Rewrite your entire message with all of the required xml elements.",
570570
),
571571
)
572572
except Exception as e: # noqa: BLE001
573573
next_pipeline.add(
574574
Message.from_model(
575575
SystemErrorModel(content=str(e)),
576-
suffix="Rewrite your entire message with all the required xml structure.",
576+
suffix="Rewrite your entire message with all of the required xml elements.",
577577
),
578578
)
579579
else: # parsed successfully
@@ -1084,8 +1084,7 @@ def prompt(
10841084
generator_id: str | None = None,
10851085
tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None,
10861086
system_prompt: str | None = None,
1087-
) -> t.Callable[[t.Callable[P, t.Coroutine[t.Any, t.Any, R]] | t.Callable[P, R]], Prompt[P, R]]:
1088-
...
1087+
) -> t.Callable[[t.Callable[P, t.Coroutine[t.Any, t.Any, R]] | t.Callable[P, R]], Prompt[P, R]]: ...
10891088

10901089

10911090
@t.overload
@@ -1098,8 +1097,7 @@ def prompt(
10981097
generator_id: str | None = None,
10991098
tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None,
11001099
system_prompt: str | None = None,
1101-
) -> Prompt[P, R]:
1102-
...
1100+
) -> Prompt[P, R]: ...
11031101

11041102

11051103
@t.overload
@@ -1112,8 +1110,7 @@ def prompt(
11121110
generator_id: str | None = None,
11131111
tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None,
11141112
system_prompt: str | None = None,
1115-
) -> Prompt[P, R]:
1116-
...
1113+
) -> Prompt[P, R]: ...
11171114

11181115

11191116
def prompt(
@@ -1214,8 +1211,12 @@ def make_prompt(
12141211

12151212

12161213
@t.overload
1217-
def make_prompt(content: str, return_type: type[R], *, ctx: Ctx | None = None) -> Prompt[..., R]:
1218-
...
1214+
def make_prompt(
1215+
content: str,
1216+
return_type: type[R],
1217+
*,
1218+
ctx: Ctx | None = None,
1219+
) -> Prompt[..., R]: ...
12191220

12201221

12211222
@t.overload
@@ -1224,8 +1225,7 @@ def make_prompt(
12241225
return_type: None = None,
12251226
*,
12261227
ctx: Ctx | None = None,
1227-
) -> Prompt[..., str]:
1228-
...
1228+
) -> Prompt[..., str]: ...
12291229

12301230

12311231
def make_prompt(

rigging/tool/base.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import typing_extensions as te
1515
from pydantic import TypeAdapter
1616

17-
from rigging.error import ToolDefinitionError
17+
from rigging.error import Stop, ToolDefinitionError
1818
from rigging.model import Model, make_from_schema, make_from_signature
1919
from rigging.tool.api import ApiFunctionDefinition, ApiToolCall, ApiToolDefinition
2020
from rigging.tool.native import (
@@ -258,7 +258,7 @@ def json_definition(self) -> JsonInXmlToolDefinition:
258258
parameters=json.dumps(self.parameters_schema),
259259
)
260260

261-
async def handle_tool_call( # noqa: PLR0912
261+
async def handle_tool_call( # noqa: PLR0912, PLR0915
262262
self,
263263
tool_call: ApiToolCall | XmlToolCall | JsonInXmlToolCall,
264264
) -> tuple["Message", bool]:
@@ -269,7 +269,8 @@ async def handle_tool_call( # noqa: PLR0912
269269
tool_call: The tool call to handle.
270270
271271
Returns:
272-
The message to send back to the generator or `None` if iterative tool calling should not proceed any further.
272+
A tuple containing the message to send back to the generator and a
273+
boolean indicating whether tool calling should stop.
273274
"""
274275

275276
from rigging.message import ContentText, ContentTypes, Message
@@ -330,10 +331,16 @@ async def handle_tool_call( # noqa: PLR0912
330331

331332
# Call the function
332333

334+
stop = False
335+
333336
try:
334337
result: t.Any = self.fn(**kwargs) # type: ignore [call-arg]
335338
if inspect.isawaitable(result):
336339
result = await result
340+
except Stop as e:
341+
result = f"<rg:stop>{e.message}</rg:stop>"
342+
span.set_attribute("stop", True)
343+
stop = True
337344
except Exception as e:
338345
if self.catch is True or (
339346
not isinstance(self.catch, bool) and isinstance(e, tuple(self.catch))
@@ -350,11 +357,6 @@ async def handle_tool_call( # noqa: PLR0912
350357
else Message("user")
351358
)
352359

353-
# If the tool returns nothing back to us, we'll assume that
354-
# they do not want to proceed with additional tool calling
355-
356-
should_continue = result is not None
357-
358360
# If the tool gave us back anything that looks like a message, we'll
359361
# just pass it along. Otherwise we need to box up the result.
360362

@@ -395,7 +397,7 @@ async def handle_tool_call( # noqa: PLR0912
395397
result=message.content_parts[0].text,
396398
).to_pretty_xml()
397399

398-
return message, should_continue
400+
return message, stop
399401

400402
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
401403
return self.fn(*args, **kwargs)

0 commit comments

Comments
 (0)