@@ -175,17 +175,14 @@ def parse_result(self, result: List[Generation], *, partial: bool = False) -> An
175
175
Returns:
176
176
Structured output.
177
177
"""
178
- if not result or not isinstance (result [0 ], ChatGeneration ):
178
+ if (
179
+ not result
180
+ or not isinstance (result [0 ], ChatGeneration )
181
+ or not isinstance (result [0 ].message , AIMessage )
182
+ or not result [0 ].message .tool_calls
183
+ ):
179
184
return None if self .first_tool_only else []
180
- message = result [0 ].message
181
- if len (message .content ) > 0 :
182
- tool_calls : List = []
183
- else :
184
- content = cast (AIMessage , message )
185
- _tool_calls = [dict (tc ) for tc in content .tool_calls ]
186
- # Map tool call id to index
187
- id_to_index = {block ["id" ]: i for i , block in enumerate (_tool_calls )}
188
- tool_calls = [{** tc , "index" : id_to_index [tc ["id" ]]} for tc in _tool_calls ]
185
+ tool_calls : Any = result [0 ].message .tool_calls
189
186
if self .pydantic_schemas :
190
187
tool_calls = [self ._pydantic_parse (tc ) for tc in tool_calls ]
191
188
elif self .args_only :
@@ -194,11 +191,11 @@ def parse_result(self, result: List[Generation], *, partial: bool = False) -> An
194
191
pass
195
192
196
193
if self .first_tool_only :
197
- return tool_calls [0 ] if tool_calls else None
194
+ return tool_calls [0 ]
198
195
else :
199
- return [ tool_call for tool_call in tool_calls ]
196
+ return tool_calls
200
197
201
- def _pydantic_parse (self , tool_call : dict ) -> BaseModel :
198
+ def _pydantic_parse (self , tool_call : ToolCall ) -> BaseModel :
202
199
cls_ = {schema .__name__ : schema for schema in self .pydantic_schemas or []}[
203
200
tool_call ["name" ]
204
201
]
0 commit comments