Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ class Joke(BaseModel):
max_tokens: Optional[int] = None
"""Max tokens to generate."""

stop_sequences: Optional[List[str]] = Field(None, alias="stop")
stop_sequences: Optional[List[str]] = Field(default=None, alias="stop")
"""Stop generation if any of these substrings occurs."""

temperature: Optional[float] = None
Expand Down Expand Up @@ -308,12 +308,17 @@ class Joke(BaseModel):
have an ARN associated with them.
"""

endpoint_url: Optional[str] = Field(None, alias="base_url")
endpoint_url: Optional[str] = Field(default=None, alias="base_url")
"""Needed if you don't want to default to us-east-1 endpoint"""

config: Any = None
"""An optional botocore.config.Config instance to pass to the client."""

formatted_tools: List[
Dict[Literal["toolSpec"], Dict[str, Union[Dict[str, Any], str]]]
] = Field(default_factory=list, exclude=True)
""""Formatted tools to be stored and used in the toolConfig parameter."""

class Config:
"""Configuration for this pydantic object."""

Expand Down Expand Up @@ -413,7 +418,8 @@ def bind_tools(
) -> Runnable[LanguageModelInput, BaseMessage]:
if tool_choice:
kwargs["tool_choice"] = _format_tool_choice(tool_choice)
return self.bind(tools=_format_tools(tools), **kwargs)
self.formatted_tools = _format_tools(tools)
return self.bind(tools=self.formatted_tools, **kwargs)

def with_structured_output(
self,
Expand Down Expand Up @@ -467,8 +473,7 @@ def _converse_params(
}
if not toolConfig and tools:
toolChoice = _format_tool_choice(toolChoice) if toolChoice else None
toolConfig = {"tools": _format_tools(tools), "toolChoice": toolChoice}

toolConfig = {"tools": self.formatted_tools, "toolChoice": toolChoice}
return _drop_none(
{
"modelId": modelId or self.model_id,
Expand Down Expand Up @@ -648,7 +653,7 @@ def _anthropic_to_bedrock(
{
"toolUse": {
"toolUseId": block["id"],
"input": block["input"],
"input": _try_to_convert_to_dict(block["input"]),
"name": block["name"],
}
}
Expand Down Expand Up @@ -852,3 +857,13 @@ def _format_openai_image_url(image_url: str) -> Dict:
"format": match.group("media_type"),
"source": {"bytes": _b64str_to_bytes(match.group("data"))},
}


def _try_to_convert_to_dict(tool_use_input: Any) -> Any:
"""Attempt to convert the toolUse.input to a dictionary."""
if isinstance(tool_use_input, str):
try:
return json.loads(tool_use_input)
except json.JSONDecodeError:
return tool_use_input
return tool_use_input