Skip to content

Commit 311dfe3

Browse files
committed
Add parameter to tools. Make tools callable directly.
1 parent 6e1eae0 commit 311dfe3

File tree

6 files changed

+107
-40
lines changed

6 files changed

+107
-40
lines changed

rigging/chat.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def inject_system_content(self, content: str) -> "Chat":
364364
self.messages[0].content += "\n\n" + content
365365
return self
366366

367-
def inject_tool_prompt(self, tools: t.Sequence[Tool], mode: ToolMode) -> "Chat":
367+
def inject_tool_prompt(self, tools: t.Sequence[Tool[..., t.Any]], mode: ToolMode) -> "Chat":
368368
"""
369369
Injects a default tool use prompt into the system prompt.
370370
@@ -695,7 +695,7 @@ def __init__(
695695
"""How to handle cache_control entries on messages."""
696696

697697
self.until_types: list[type[Model]] = []
698-
self.tools: list[Tool] = []
698+
self.tools: list[Tool[..., t.Any]] = []
699699
self.tool_mode: ToolMode = "auto"
700700
self.api_tool_choice: ApiToolChoice | None = None
701701
self.inject_tool_prompt = True
@@ -1066,7 +1066,9 @@ def wrap(self, func: t.Callable[[CallableT], CallableT]) -> "ChatPipeline":
10661066

10671067
def using(
10681068
self,
1069-
*tools: Tool | t.Callable[..., t.Any] | t.Sequence[Tool | t.Callable[..., t.Any]],
1069+
*tools: Tool[..., t.Any]
1070+
| t.Callable[..., t.Any]
1071+
| t.Sequence[Tool[..., t.Any] | t.Callable[..., t.Any]],
10701072
mode: ToolMode | None = None,
10711073
choice: ApiToolChoice | None = None,
10721074
max_depth: int = DEFAULT_MAX_DEPTH,

rigging/prompt.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ class Prompt(t.Generic[P, R]):
460460

461461
params: GenerateParams | None = None
462462
"""The parameters to be used when generating chats for this prompt."""
463-
tools: list[Tool] = dataclasses.field(default_factory=list)
463+
tools: list[Tool[..., t.Any]] = dataclasses.field(default_factory=list)
464464
"""The API tools to be made available when generating chats for this prompt."""
465465
system_prompt: str | None = None
466466
"""A system prompt fragment to be injected into the messages before generation."""
@@ -1078,7 +1078,7 @@ def prompt(
10781078
pipeline: ChatPipeline | None = None,
10791079
generator: Generator | None = None,
10801080
generator_id: str | None = None,
1081-
tools: list[Tool | t.Callable[..., t.Any]] | None = None,
1081+
tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None,
10821082
system_prompt: str | None = None,
10831083
) -> t.Callable[[t.Callable[P, t.Coroutine[t.Any, t.Any, R]] | t.Callable[P, R]], Prompt[P, R]]:
10841084
...
@@ -1092,7 +1092,7 @@ def prompt(
10921092
pipeline: ChatPipeline | None = None,
10931093
generator: Generator | None = None,
10941094
generator_id: str | None = None,
1095-
tools: list[Tool | t.Callable[..., t.Any]] | None = None,
1095+
tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None,
10961096
system_prompt: str | None = None,
10971097
) -> Prompt[P, R]:
10981098
...
@@ -1106,7 +1106,7 @@ def prompt(
11061106
pipeline: ChatPipeline | None = None,
11071107
generator: Generator | None = None,
11081108
generator_id: str | None = None,
1109-
tools: list[Tool | t.Callable[..., t.Any]] | None = None,
1109+
tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None,
11101110
system_prompt: str | None = None,
11111111
) -> Prompt[P, R]:
11121112
...
@@ -1119,7 +1119,7 @@ def prompt(
11191119
pipeline: ChatPipeline | None = None,
11201120
generator: Generator | None = None,
11211121
generator_id: str | None = None,
1122-
tools: list[Tool | t.Callable[..., t.Any]] | None = None,
1122+
tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None,
11231123
system_prompt: str | None = None,
11241124
) -> (
11251125
t.Callable[[t.Callable[P, t.Coroutine[t.Any, t.Any, R]] | t.Callable[P, R]], Prompt[P, R]]

rigging/tool/base.py

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
if t.TYPE_CHECKING:
2828
from rigging.message import Message
2929

30+
P = t.ParamSpec("P")
31+
R = t.TypeVar("R")
32+
3033
ToolMode = t.Literal["auto", "api", "xml", "json-in-xml"]
3134
"""
3235
How tool calls are handled.
@@ -39,7 +42,7 @@
3942

4043

4144
@dataclass
42-
class Tool:
45+
class Tool(t.Generic[P, R]):
4346
"""Base class for representing a tool to a generator."""
4447

4548
name: str
@@ -48,8 +51,16 @@ class Tool:
4851
"""A description of the tool."""
4952
parameters_schema: dict[str, t.Any]
5053
"""The JSON schema for the tool's parameters."""
51-
fn: t.Callable[..., t.Any]
54+
fn: t.Callable[P, R]
5255
"""The function to call."""
56+
catch: bool | set[type[Exception]] = False
57+
"""
58+
Whether to catch exceptions and return them as messages.
59+
60+
- `False`: Do not catch exceptions.
61+
- `True`: Catch all exceptions.
62+
- `list[type[Exception]]`: Catch only the specified exceptions.
63+
"""
5364

5465
_signature: inspect.Signature | None = field(default=None, init=False, repr=False)
5566
_type_adapter: TypeAdapter[t.Any] | None = field(default=None, init=False, repr=False)
@@ -68,11 +79,12 @@ class Tool:
6879
@classmethod
6980
def from_callable(
7081
cls,
71-
fn: t.Callable[..., t.Any],
82+
fn: t.Callable[P, R],
7283
*,
7384
name: str | None = None,
7485
description: str | None = None,
75-
) -> "Tool":
86+
catch: bool | t.Iterable[type[Exception]] = False,
87+
) -> "Tool[P, R]":
7688
from rigging.prompt import Prompt
7789

7890
fn_for_signature = fn
@@ -84,7 +96,7 @@ def from_callable(
8496

8597
if isinstance(fn, Prompt):
8698
fn_for_signature = fn.func # type: ignore [assignment]
87-
fn = fn.run
99+
fn = fn.run # type: ignore [assignment]
88100

89101
# In the case that we are recieving a bound function which is tracking
90102
# an originating prompt, unwrap the prompt and use it's function for
@@ -162,9 +174,11 @@ def empty_func(*args, **kwargs): # type: ignore [no-untyped-def] # noqa: ARG001
162174
description=description or fn_for_signature.__doc__ or "",
163175
parameters_schema=schema,
164176
fn=fn,
177+
catch=catch if isinstance(catch, bool) else set(catch),
165178
)
166179

167180
self._signature = signature
181+
self.__signature__ = signature # type: ignore [attr-defined]
168182

169183
# For handling API calls, we'll use the type adapter to validate
170184
# the arguments before calling the function
@@ -226,7 +240,7 @@ async def handle_tool_call(
226240
tool_call: The tool call to handle.
227241
228242
Returns:
229-
The message to send back to the generator or None if tool calling should not proceed.
243+
The message to send back to the generator or `None` if iterative tool calling should not proceed any further.
230244
"""
231245

232246
from rigging.message import ContentText, ContentTypes, Message
@@ -248,12 +262,12 @@ async def handle_tool_call(
248262

249263
# Load + validate arguments
250264

251-
args: dict[str, t.Any]
265+
kwargs: dict[str, t.Any]
252266
if isinstance(tool_call, ApiToolCall | JsonInXmlToolCall):
253-
args = json.loads(tool_call_parameters)
267+
kwargs = json.loads(tool_call_parameters)
254268

255269
if self._type_adapter is not None:
256-
args = self._type_adapter.validate_python(args)
270+
kwargs = self._type_adapter.validate_python(kwargs)
257271

258272
elif isinstance(tool_call, XmlToolCall):
259273
parsed = self.model.from_text(
@@ -268,18 +282,26 @@ async def handle_tool_call(
268282
# argument object instances. We'll just flatten the
269283
# model into a dictionary for the function call.
270284

271-
args = {
285+
kwargs = {
272286
field_name: getattr(parameters, field_name, None)
273287
for field_name in self.model.model_fields
274288
}
275289

276-
span.set_attribute("arguments", args)
290+
span.set_attribute("arguments", kwargs)
277291

278292
# Call the function
279293

280-
result = self.fn(**args)
281-
if inspect.isawaitable(result):
282-
result = await result
294+
try:
295+
result: t.Any = self.fn(**kwargs) # type: ignore [call-arg]
296+
if inspect.isawaitable(result):
297+
result = await result
298+
except Exception as e: # noqa: BLE001
299+
if self.catch is True or (
300+
not isinstance(self.catch, bool) and isinstance(e, tuple(self.catch))
301+
):
302+
result = f'<error type="{e.__class__.__name__}">:{e}</error>'
303+
else:
304+
raise
283305

284306
span.set_attribute("result", result)
285307

@@ -331,6 +353,9 @@ async def handle_tool_call(
331353

332354
return message, True
333355

356+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
357+
return self.fn(*args, **kwargs)
358+
334359

335360
# Decorator
336361

@@ -342,42 +367,49 @@ def tool(
342367
*,
343368
name: str | None = None,
344369
description: str | None = None,
345-
) -> t.Callable[[t.Callable[..., t.Any]], Tool]:
370+
catch: bool | t.Iterable[type[Exception]] = False,
371+
) -> t.Callable[[t.Callable[P, R]], Tool[P, R]]:
346372
...
347373

348374

349375
@t.overload
350376
def tool(
351-
func: t.Callable[..., t.Any],
377+
func: t.Callable[P, R],
352378
/,
353379
*,
354380
name: str | None = None,
355381
description: str | None = None,
356-
) -> Tool:
382+
catch: bool | t.Iterable[type[Exception]] = False,
383+
) -> Tool[P, R]:
357384
...
358385

359386

360387
def tool(
361-
func: t.Callable[..., t.Any] | None = None,
388+
func: t.Callable[P, R] | None = None,
362389
/,
363390
*,
364391
name: str | None = None,
365392
description: str | None = None,
366-
) -> t.Callable[[t.Callable[..., t.Any]], Tool] | Tool:
393+
catch: bool | t.Iterable[type[Exception]] = False,
394+
) -> t.Callable[[t.Callable[P, R]], Tool[P, R]] | Tool[P, R]:
367395
"""
368396
Decorator for creating a Tool, useful for overriding a name or description.
369397
370398
Args:
371399
func: The function to wrap.
372400
name: The name of the tool.
373401
description: The description of the tool.
402+
catch: Whether to catch exceptions and return them as messages.
403+
- `False`: Do not catch exceptions.
404+
- `True`: Catch all exceptions.
405+
- `list[type[Exception]]`: Catch only the specified exceptions.
374406
375407
Returns:
376408
The decorated Tool object.
377409
"""
378410

379-
def make_tool(func: t.Callable[..., t.Any]) -> Tool:
380-
return Tool.from_callable(func, name=name, description=description)
411+
def make_tool(func: t.Callable[..., t.Any]) -> Tool[P, R]:
412+
return Tool.from_callable(func, name=name, description=description, catch=catch)
381413

382414
if func is not None:
383415
return make_tool(func)

rigging/tool/mcp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class MCPClient:
6969
"""The transport to use"""
7070
connection: StdioConnection | SSEConnection
7171
"""Connection configuration"""
72-
tools: list[Tool]
72+
tools: list[Tool[..., t.Any]]
7373
"""A list of tools available on the server"""
7474

7575
def __init__(self, transport: Transport, connection: StdioConnection | SSEConnection) -> None:

rigging/tool/robopages.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def execute_on_server(**kwargs: t.Any) -> t.Any:
4444
return execute_on_server
4545

4646

47-
def robopages(url: str, *, name_filter: str | None = None) -> list[Tool]:
47+
def robopages(url: str, *, name_filter: str | None = None) -> list[Tool[..., t.Any]]:
4848
"""
4949
Create a list of tools from a Robopages server.
5050
@@ -83,7 +83,7 @@ def robopages(url: str, *, name_filter: str | None = None) -> list[Tool]:
8383

8484
logger.info(f"Fetched {len(tool_definitions)} functions from Robopages ({url})")
8585

86-
tools: list[Tool] = []
86+
tools: list[Tool[..., t.Any]] = []
8787
for definition in tool_definitions:
8888
function = definition.function
8989

0 commit comments

Comments
 (0)