@@ -54,20 +54,25 @@ def get_tool_def(self, name: str) -> ToolDefinition | None:
54
54
except KeyError :
55
55
return None
56
56
57
- async def handle_call (self , call : ToolCallPart , allow_partial : bool = False ) -> Any :
57
+ async def handle_call (
58
+ self , call : ToolCallPart , allow_partial : bool = False , wrap_validation_errors : bool = True
59
+ ) -> Any :
58
60
"""Handle a tool call by validating the arguments, calling the tool, and handling retries.
59
61
60
62
Args:
61
63
call: The tool call part to handle.
62
64
allow_partial: Whether to allow partial validation of the tool arguments.
65
+ wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
63
66
"""
64
67
if (tool := self .tools .get (call .tool_name )) and tool .tool_def .kind == 'output' :
65
68
# Output tool calls are not traced
66
- return await self ._call_tool (call , allow_partial )
69
+ return await self ._call_tool (call , allow_partial , wrap_validation_errors )
67
70
else :
68
- return await self ._call_tool_traced (call , allow_partial )
71
+ return await self ._call_tool_traced (call , allow_partial , wrap_validation_errors )
69
72
70
- async def _call_tool (self , call : ToolCallPart , allow_partial : bool = False ) -> Any :
73
+ async def _call_tool (
74
+ self , call : ToolCallPart , allow_partial : bool = False , wrap_validation_errors : bool = True
75
+ ) -> Any :
71
76
name = call .tool_name
72
77
tool = self .tools .get (name )
73
78
try :
@@ -100,30 +105,35 @@ async def _call_tool(self, call: ToolCallPart, allow_partial: bool = False) -> A
100
105
if current_retry == max_retries :
101
106
raise UnexpectedModelBehavior (f'Tool { name !r} exceeded max retries count of { max_retries } ' ) from e
102
107
else :
103
- if isinstance (e , ValidationError ):
104
- m = _messages .RetryPromptPart (
105
- tool_name = name ,
106
- content = e .errors (include_url = False , include_context = False ),
107
- tool_call_id = call .tool_call_id ,
108
- )
109
- e = ToolRetryError (m )
110
- elif isinstance (e , ModelRetry ):
111
- m = _messages .RetryPromptPart (
112
- tool_name = name ,
113
- content = e .message ,
114
- tool_call_id = call .tool_call_id ,
115
- )
116
- e = ToolRetryError (m )
117
- else :
118
- assert_never (e )
108
+ if wrap_validation_errors :
109
+ if isinstance (e , ValidationError ):
110
+ m = _messages .RetryPromptPart (
111
+ tool_name = name ,
112
+ content = e .errors (include_url = False , include_context = False ),
113
+ tool_call_id = call .tool_call_id ,
114
+ )
115
+ e = ToolRetryError (m )
116
+ elif isinstance (e , ModelRetry ):
117
+ m = _messages .RetryPromptPart (
118
+ tool_name = name ,
119
+ content = e .message ,
120
+ tool_call_id = call .tool_call_id ,
121
+ )
122
+ e = ToolRetryError (m )
123
+ else :
124
+ assert_never (e )
125
+
126
+ if not allow_partial :
127
+ self .ctx .retries [name ] = current_retry + 1
119
128
120
- self .ctx .retries [name ] = current_retry + 1
121
129
raise e
122
130
else :
123
131
self .ctx .retries .pop (name , None )
124
132
return output
125
133
126
- async def _call_tool_traced (self , call : ToolCallPart , allow_partial : bool = False ) -> Any :
134
+ async def _call_tool_traced (
135
+ self , call : ToolCallPart , allow_partial : bool = False , wrap_validation_errors : bool = True
136
+ ) -> Any :
127
137
"""See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
128
138
span_attributes = {
129
139
'gen_ai.tool.name' : call .tool_name ,
@@ -152,7 +162,7 @@ async def _call_tool_traced(self, call: ToolCallPart, allow_partial: bool = Fals
152
162
}
153
163
with self .ctx .tracer .start_as_current_span ('running tool' , attributes = span_attributes ) as span :
154
164
try :
155
- tool_result = await self ._call_tool (call , allow_partial )
165
+ tool_result = await self ._call_tool (call , allow_partial , wrap_validation_errors )
156
166
except ToolRetryError as e :
157
167
part = e .tool_retry
158
168
if self .ctx .trace_include_content and span .is_recording ():
0 commit comments