Skip to content

Commit 9532515

Browse files
committed
aws[minor]: Add ChatModel that uses Bedrock.converse API
1 parent 99409d8 commit 9532515

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

libs/aws/langchain_aws/function_calling.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -175,17 +175,14 @@ def parse_result(self, result: List[Generation], *, partial: bool = False) -> An
175175
Returns:
176176
Structured output.
177177
"""
178-
if not result or not isinstance(result[0], ChatGeneration):
178+
if (
179+
not result
180+
or not isinstance(result[0], ChatGeneration)
181+
or not isinstance(result[0].message, AIMessage)
182+
or not result[0].message.tool_calls
183+
):
179184
return None if self.first_tool_only else []
180-
message = result[0].message
181-
if len(message.content) > 0:
182-
tool_calls: List = []
183-
else:
184-
content = cast(AIMessage, message)
185-
_tool_calls = [dict(tc) for tc in content.tool_calls]
186-
# Map tool call id to index
187-
id_to_index = {block["id"]: i for i, block in enumerate(_tool_calls)}
188-
tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls]
185+
tool_calls: Any = result[0].message.tool_calls
189186
if self.pydantic_schemas:
190187
tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
191188
elif self.args_only:
@@ -194,11 +191,11 @@ def parse_result(self, result: List[Generation], *, partial: bool = False) -> An
194191
pass
195192

196193
if self.first_tool_only:
197-
return tool_calls[0] if tool_calls else None
194+
return tool_calls[0]
198195
else:
199-
return [tool_call for tool_call in tool_calls]
196+
return tool_calls
200197

201-
def _pydantic_parse(self, tool_call: dict) -> BaseModel:
198+
def _pydantic_parse(self, tool_call: ToolCall) -> BaseModel:
202199
cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[
203200
tool_call["name"]
204201
]

0 commit comments

Comments
 (0)