Skip to content

Commit 93fc7e2

Browse files
authored
fix: Tool Error Handling (#222)
* Reworking error handling for pipeline tools/parsing to give more control * Fixing tests
1 parent 0651021 commit 93fc7e2

File tree

11 files changed

+580
-152
lines changed

11 files changed

+580
-152
lines changed

docs/api/chat.mdx

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,13 +1268,20 @@ def __init__(
12681268
self.scorers: list[Scorer[Chat]] = []
12691269
"""List of dreadnode scorers to evaluate the generated chat upon completion."""
12701270

1271-
self.until_types: list[type[Model]] = []
1271+
self.until_parsed_as_types: list[type[Model]] = []
1272+
self.until_parsed_as_catch: bool = True
12721273
self.tools: list[Tool[..., t.Any]] = []
12731274
self.tool_mode: ToolMode = "auto"
12741275
self.inject_tool_prompt = True
12751276
self.add_tool_stop_token = True
1276-
self.then_callbacks: list[tuple[ThenChatCallback, int, bool]] = []
1277-
self.map_callbacks: list[tuple[MapChatCallback, int, bool]] = []
1277+
self.then_callbacks: list[
1278+
# callback, max_depth, as_task
1279+
tuple[ThenChatCallback, int, bool]
1280+
] = []
1281+
self.map_callbacks: list[
1282+
# callback, max_depth, as_task
1283+
tuple[MapChatCallback, int, bool]
1284+
] = []
12781285
self.watch_callbacks: list[WatchChatCallback] = watch_callbacks or []
12791286
self.transforms: list[Transform] = []
12801287
```
@@ -1726,7 +1733,7 @@ def clone(
17261733
)
17271734
new.chat = (chat or self.chat).clone()
17281735
if not only_messages:
1729-
new.until_types = self.until_types.copy()
1736+
new.until_parsed_as_types = self.until_parsed_as_types.copy()
17301737
new.tools = self.tools.copy()
17311738
new.tool_mode = self.tool_mode
17321739
new.metadata = deepcopy(self.metadata)
@@ -3084,6 +3091,7 @@ def transform(
30843091
until_parsed_as(
30853092
*types: type[ModelT],
30863093
max_depth: int = DEFAULT_MAX_DEPTH,
3094+
catch: bool | None = None,
30873095
attempt_recovery: bool | None = None,
30883096
drop_dialog: bool | None = None,
30893097
max_rounds: int | None = None,
@@ -3105,6 +3113,11 @@ before the generation process completes.
31053113
`DEFAULT_MAX_DEPTH`
31063114
)
31073115
–The maximum depth to re-attempt parsing using recursive pipelines (this is shared between all types).
3116+
* **`catch`**
3117+
(`bool | None`, default:
3118+
`None`
3119+
)
3120+
–Whether to catch exceptions and return them as messages automatically, otherwise raise them to the pipeline.
31083121
* **`attempt_recovery`**
31093122
(`bool | None`, default:
31103123
`None`
@@ -3132,6 +3145,7 @@ def until_parsed_as(
31323145
self,
31333146
*types: type[ModelT],
31343147
max_depth: int = DEFAULT_MAX_DEPTH,
3148+
catch: bool | None = None,
31353149
# deprecated
31363150
attempt_recovery: bool | None = None,
31373151
drop_dialog: bool | None = None,
@@ -3144,6 +3158,7 @@ def until_parsed_as(
31443158
Args:
31453159
*types: The type or types of models to wait for.
31463160
max_depth: The maximum depth to re-attempt parsing using recursive pipelines (this is shared between all types).
3161+
catch: Whether to catch exceptions and return them as messages automatically, otherwise raise them to the pipeline.
31473162
attempt_recovery: deprecated, recovery is always attempted.
31483163
drop_dialog: deprecated, the full dialog is always returned.
31493164
max_rounds: deprecated, use `max_depth` instead.
@@ -3170,7 +3185,8 @@ def until_parsed_as(
31703185
stacklevel=2,
31713186
)
31723187

3173-
self.until_types = list(types)
3188+
self.until_parsed_as_types = list(types)
3189+
self.until_parsed_as_catch = catch or self.until_parsed_as_catch
31743190

31753191
max_depth = max_rounds or max_depth
31763192
self.then_callbacks = [
@@ -3197,6 +3213,7 @@ using(
31973213
choice: ToolChoice | None = None,
31983214
max_depth: int = DEFAULT_MAX_DEPTH,
31993215
add_stop_token: bool | None = None,
3216+
catch: bool | Iterable[type[Exception]] | None = None,
32003217
) -> ChatPipeline
32013218
```
32023219

@@ -3240,6 +3257,13 @@ wrapped in xml tags.
32403257
)
32413258
–When using "xml" tool transforms, use stop tokens to
32423259
immediately process a tool call when observed.
3260+
* **`catch`**
3261+
(`bool | Iterable[type[Exception]] | None`, default:
3262+
`None`
3263+
)
3264+
–Override the catch setting for all incoming tools, or leave `None` to use the tool's default.
3265+
By default, catches `json.JSONDecodeError` and `ValidationError`. Set to `{}` to let the pipeline
3266+
handle all tool exceptions.
32433267

32443268
**Returns:**
32453269

@@ -3274,6 +3298,7 @@ def using(
32743298
choice: ToolChoice | None = None,
32753299
max_depth: int = DEFAULT_MAX_DEPTH,
32763300
add_stop_token: bool | None = None,
3301+
catch: bool | t.Iterable[type[Exception]] | None = None,
32773302
) -> "ChatPipeline":
32783303
"""
32793304
Adds a tool or a sequence of tools to participate in the generation process.
@@ -3294,6 +3319,9 @@ def using(
32943319
max_depth: The maximum depth for recursive tool calls (this is shared between all tools).
32953320
add_stop_token: When using "xml" tool transforms, use stop tokens to
32963321
immediately process a tool call when observed.
3322+
catch: Override the catch setting for all incoming tools, or leave `None` to use the tool's default.
3323+
By default, catches `json.JSONDecodeError` and `ValidationError`. Set to `{}` to let the pipeline
3324+
handle all tool exceptions.
32973325
32983326
Returns:
32993327
The updated pipeline.
@@ -3332,6 +3360,9 @@ def using(
33323360
else:
33333361
_tools.append(tool)
33343362

3363+
if catch is not None:
3364+
_tools = [tool.with_(catch=catch) for tool in _tools]
3365+
33353366
existing_names = {tool.name for tool in self.tools}
33363367
new_names = {tool.name for tool in _tools}
33373368
for name in existing_names & new_names:

docs/api/message.mdx

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1175,6 +1175,8 @@ from_model(
11751175
models: Model | Sequence[Model],
11761176
role: Role = "user",
11771177
suffix: str | None = None,
1178+
tool_call_id: str | None = None,
1179+
metadata: dict[str, Any] | None = None,
11781180
) -> Message
11791181
```
11801182

@@ -1195,6 +1197,16 @@ Create a Message object from one or more Model objects.
11951197
`None`
11961198
)
11971199
–A suffix to append to the content.
1200+
* **`metadata`**
1201+
(`dict[str, Any] | None`, default:
1202+
`None`
1203+
)
1204+
–Additional metadata for the Message.
1205+
* **`tool_call_id`**
1206+
(`str | None`, default:
1207+
`None`
1208+
)
1209+
–The ID of the tool call associated with this message.
11981210

11991211
**Returns:**
12001212

@@ -1209,6 +1221,8 @@ def from_model(
12091221
models: Model | t.Sequence[Model],
12101222
role: Role = "user",
12111223
suffix: str | None = None,
1224+
tool_call_id: str | None = None,
1225+
metadata: dict[str, t.Any] | None = None,
12121226
) -> "Message":
12131227
"""
12141228
Create a Message object from one or more Model objects.
@@ -1217,6 +1231,8 @@ def from_model(
12171231
models: The Model object(s) to convert to a Message.
12181232
role: The role of the Message.
12191233
suffix: A suffix to append to the content.
1234+
metadata: Additional metadata for the Message.
1235+
tool_call_id: The ID of the tool call associated with this message.
12201236
12211237
Returns:
12221238
The created Message object.
@@ -1240,7 +1256,18 @@ def from_model(
12401256
if suffix is not None:
12411257
content += f"\n{suffix}"
12421258

1243-
return cls(role=role, content=content, slices=slices_)
1259+
# If we are building this message from an error, add
1260+
# the error content to the metadata
1261+
if isinstance(models, ErrorModel):
1262+
metadata = {"error": models.content}
1263+
1264+
return cls(
1265+
role=role,
1266+
content=content,
1267+
slices=slices_,
1268+
tool_call_id=tool_call_id,
1269+
metadata=metadata or {},
1270+
)
12441271
```
12451272

12461273

0 commit comments

Comments
 (0)