2727if t .TYPE_CHECKING :
2828 from rigging .message import Message
2929
30+ P = t .ParamSpec ("P" )
31+ R = t .TypeVar ("R" )
32+
3033ToolMode = t .Literal ["auto" , "api" , "xml" , "json-in-xml" ]
3134"""
3235How tool calls are handled.
3942
4043
4144@dataclass
42- class Tool :
45+ class Tool ( t . Generic [ P , R ]) :
4346 """Base class for representing a tool to a generator."""
4447
4548 name : str
@@ -48,8 +51,16 @@ class Tool:
4851 """A description of the tool."""
4952 parameters_schema : dict [str , t .Any ]
5053 """The JSON schema for the tool's parameters."""
51- fn : t .Callable [..., t . Any ]
54+ fn : t .Callable [P , R ]
5255 """The function to call."""
56+ catch : bool | set [type [Exception ]] = False
57+ """
58+ Whether to catch exceptions and return them as messages.
59+
60+ - `False`: Do not catch exceptions.
61+ - `True`: Catch all exceptions.
62+ - `list[type[Exception]]`: Catch only the specified exceptions.
63+ """
5364
5465 _signature : inspect .Signature | None = field (default = None , init = False , repr = False )
5566 _type_adapter : TypeAdapter [t .Any ] | None = field (default = None , init = False , repr = False )
@@ -68,11 +79,12 @@ class Tool:
6879 @classmethod
6980 def from_callable (
7081 cls ,
71- fn : t .Callable [..., t . Any ],
82+ fn : t .Callable [P , R ],
7283 * ,
7384 name : str | None = None ,
7485 description : str | None = None ,
75- ) -> "Tool" :
86+ catch : bool | t .Iterable [type [Exception ]] = False ,
87+ ) -> "Tool[P, R]" :
7688 from rigging .prompt import Prompt
7789
7890 fn_for_signature = fn
@@ -84,7 +96,7 @@ def from_callable(
8496
8597 if isinstance (fn , Prompt ):
8698 fn_for_signature = fn .func # type: ignore [assignment]
87- fn = fn .run
99+ fn = fn .run # type: ignore [assignment]
88100
89101 # In the case that we are recieving a bound function which is tracking
90102 # an originating prompt, unwrap the prompt and use it's function for
@@ -162,9 +174,11 @@ def empty_func(*args, **kwargs): # type: ignore [no-untyped-def] # noqa: ARG001
162174 description = description or fn_for_signature .__doc__ or "" ,
163175 parameters_schema = schema ,
164176 fn = fn ,
177+ catch = catch if isinstance (catch , bool ) else set (catch ),
165178 )
166179
167180 self ._signature = signature
181+ self .__signature__ = signature # type: ignore [attr-defined]
168182
169183 # For handling API calls, we'll use the type adapter to validate
170184 # the arguments before calling the function
@@ -226,7 +240,7 @@ async def handle_tool_call(
226240 tool_call: The tool call to handle.
227241
228242 Returns:
229- The message to send back to the generator or None if tool calling should not proceed.
243+ The message to send back to the generator or ` None` if iterative tool calling should not proceed any further .
230244 """
231245
232246 from rigging .message import ContentText , ContentTypes , Message
@@ -248,12 +262,12 @@ async def handle_tool_call(
248262
249263 # Load + validate arguments
250264
251- args : dict [str , t .Any ]
265+ kwargs : dict [str , t .Any ]
252266 if isinstance (tool_call , ApiToolCall | JsonInXmlToolCall ):
253- args = json .loads (tool_call_parameters )
267+ kwargs = json .loads (tool_call_parameters )
254268
255269 if self ._type_adapter is not None :
256- args = self ._type_adapter .validate_python (args )
270+ kwargs = self ._type_adapter .validate_python (kwargs )
257271
258272 elif isinstance (tool_call , XmlToolCall ):
259273 parsed = self .model .from_text (
@@ -268,18 +282,26 @@ async def handle_tool_call(
268282 # argument object instances. We'll just flatten the
269283 # model into a dictionary for the function call.
270284
271- args = {
285+ kwargs = {
272286 field_name : getattr (parameters , field_name , None )
273287 for field_name in self .model .model_fields
274288 }
275289
276- span .set_attribute ("arguments" , args )
290+ span .set_attribute ("arguments" , kwargs )
277291
278292 # Call the function
279293
280- result = self .fn (** args )
281- if inspect .isawaitable (result ):
282- result = await result
294+ try :
295+ result : t .Any = self .fn (** kwargs ) # type: ignore [call-arg]
296+ if inspect .isawaitable (result ):
297+ result = await result
298+ except Exception as e : # noqa: BLE001
299+ if self .catch is True or (
300+ not isinstance (self .catch , bool ) and isinstance (e , tuple (self .catch ))
301+ ):
302+ result = f'<error type="{ e .__class__ .__name__ } ">:{ e } </error>'
303+ else :
304+ raise
283305
284306 span .set_attribute ("result" , result )
285307
@@ -331,6 +353,9 @@ async def handle_tool_call(
331353
332354 return message , True
333355
356+ def __call__ (self , * args : P .args , ** kwargs : P .kwargs ) -> R :
357+ return self .fn (* args , ** kwargs )
358+
334359
335360# Decorator
336361
@@ -342,42 +367,49 @@ def tool(
342367 * ,
343368 name : str | None = None ,
344369 description : str | None = None ,
345- ) -> t .Callable [[t .Callable [..., t .Any ]], Tool ]:
370+ catch : bool | t .Iterable [type [Exception ]] = False ,
371+ ) -> t .Callable [[t .Callable [P , R ]], Tool [P , R ]]:
346372 ...
347373
348374
349375@t .overload
350376def tool (
351- func : t .Callable [..., t . Any ],
377+ func : t .Callable [P , R ],
352378 / ,
353379 * ,
354380 name : str | None = None ,
355381 description : str | None = None ,
356- ) -> Tool :
382+ catch : bool | t .Iterable [type [Exception ]] = False ,
383+ ) -> Tool [P , R ]:
357384 ...
358385
359386
360387def tool (
361- func : t .Callable [..., t . Any ] | None = None ,
388+ func : t .Callable [P , R ] | None = None ,
362389 / ,
363390 * ,
364391 name : str | None = None ,
365392 description : str | None = None ,
366- ) -> t .Callable [[t .Callable [..., t .Any ]], Tool ] | Tool :
393+ catch : bool | t .Iterable [type [Exception ]] = False ,
394+ ) -> t .Callable [[t .Callable [P , R ]], Tool [P , R ]] | Tool [P , R ]:
367395 """
368396 Decorator for creating a Tool, useful for overriding a name or description.
369397
370398 Args:
371399 func: The function to wrap.
372400 name: The name of the tool.
373401 description: The description of the tool.
402+ catch: Whether to catch exceptions and return them as messages.
403+ - `False`: Do not catch exceptions.
404+ - `True`: Catch all exceptions.
405+ - `list[type[Exception]]`: Catch only the specified exceptions.
374406
375407 Returns:
376408 The decorated Tool object.
377409 """
378410
379- def make_tool (func : t .Callable [..., t .Any ]) -> Tool :
380- return Tool .from_callable (func , name = name , description = description )
411+ def make_tool (func : t .Callable [..., t .Any ]) -> Tool [ P , R ] :
412+ return Tool .from_callable (func , name = name , description = description , catch = catch )
381413
382414 if func is not None :
383415 return make_tool (func )
0 commit comments