Merged
56 changes: 54 additions & 2 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -35,6 +35,7 @@
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.messages.utils import convert_to_openai_messages
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
@@ -285,6 +286,50 @@ def convert_messages_to_prompt_writer(messages: List[BaseMessage]) -> str:
    )


def _convert_one_message_to_text_openai(message: BaseMessage) -> str:
    if isinstance(message, SystemMessage):
        message_text = (
            f"<|start|>system<|message|>{message.content}<|end|>"
        )
    elif isinstance(message, ChatMessage):
        # developer role messages
        message_text = (
            f"<|start|>{message.role}<|message|>{message.content}<|end|>"
        )
    elif isinstance(message, HumanMessage):
        message_text = (
            f"<|start|>user<|message|>{message.content}<|end|>"
        )
    elif isinstance(message, AIMessage):
        message_text = (
            f"<|start|>assistant<|channel|>final<|message|>{message.content}<|end|>"
        )
    elif isinstance(message, ToolMessage):
        # TODO: Tool messages in the OpenAI format should use
        # "<|start|>{toolname} to=assistant<|message|>". Need to extract the tool
        # name from the ToolMessage content or tool_call_id. For now this uses the
        # generic "to=assistant" format as a placeholder until tool calling is
        # implemented; will be resolved in a follow-up PR with full tool support.
        message_text = (
            f"<|start|>to=assistant<|channel|>commentary<|message|>{message.content}<|end|>"
        )
    else:
        raise ValueError(f"Got unknown type {message}")

    return message_text


def convert_messages_to_prompt_openai(messages: List[BaseMessage]) -> str:
    """Convert a list of messages to a Harmony-format prompt for the OpenAI
    Responses API."""
    prompt = "\n"
    for message in messages:
        prompt += _convert_one_message_to_text_openai(message)

    prompt += "<|start|>assistant\n\n"

    return prompt
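
As a quick illustration, here is what this conversion yields for a short conversation (a sketch for this writeup, not part of the PR; the message values are invented):

# Example (not part of the diff): rendering a conversation with the
# helpers above.
from langchain_core.messages import ChatMessage, HumanMessage, SystemMessage

example_prompt = convert_messages_to_prompt_openai(
    [
        SystemMessage(content="You are concise."),
        ChatMessage(role="developer", content="Answer in English."),
        HumanMessage(content="Hi!"),
    ]
)
# example_prompt is a single string (wrapped here for readability):
# "\n<|start|>system<|message|>You are concise.<|end|>"
# "<|start|>developer<|message|>Answer in English.<|end|>"
# "<|start|>user<|message|>Hi!<|end|>"
# "<|start|>assistant\n\n"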


def _format_image(image_url: str) -> Dict:
"""
Formats an image of format data:image/jpeg;base64,{b64_string}
@@ -640,6 +685,8 @@ def convert_messages_to_prompt(
            )
        elif provider == "writer":
            prompt = convert_messages_to_prompt_writer(messages=messages)
        elif provider == "openai":
            prompt = convert_messages_to_prompt_openai(messages=messages)
        else:
            raise NotImplementedError(
                f"Provider {provider} model does not support chat."
@@ -649,10 +696,11 @@
    @classmethod
    def format_messages(
        cls, provider: str, messages: List[BaseMessage]
-   ) -> Tuple[Optional[str], List[Dict]]:
+   ) -> Union[Tuple[Optional[str], List[Dict]], List[Dict]]:
        if provider == "anthropic":
            return _format_anthropic_messages(messages)

        elif provider == "openai":
            return convert_to_openai_messages(messages)
        raise NotImplementedError(
            f"Provider {provider} not supported for format_messages"
        )
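
For the new openai branch, format_messages returns the plain role/content dicts produced by langchain_core's convert_to_openai_messages rather than the (system, messages) tuple the anthropic branch returns, hence the widened return type above. A minimal sketch (not part of the PR; contents invented):

# Example (not part of the diff): expected shape, assuming
# convert_to_openai_messages produces OpenAI-style role/content dicts.
from langchain_core.messages import HumanMessage, SystemMessage

formatted = ChatPromptAdapter.format_messages(
    "openai",
    [SystemMessage(content="Be brief."), HumanMessage(content="Hi!")],
)
# formatted == [
#     {"role": "system", "content": "Be brief."},
#     {"role": "user", "content": "Hi!"},
# ]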
@@ -777,6 +825,8 @@ def _stream(
                    system = self.system_prompt_with_tools + f"\n{system}"
                else:
                    system = self.system_prompt_with_tools
        elif provider == "openai":
            formatted_messages = ChatPromptAdapter.format_messages(provider, messages)
        else:
            prompt = ChatPromptAdapter.convert_messages_to_prompt(
                provider=provider, messages=messages, model=self._get_base_model()
@@ -876,6 +926,8 @@ def _generate(
                    system = self.system_prompt_with_tools + f"\n{system}"
                else:
                    system = self.system_prompt_with_tools
        elif provider == "openai":
            formatted_messages = ChatPromptAdapter.format_messages(provider, messages)
        else:
            prompt = ChatPromptAdapter.convert_messages_to_prompt(
                provider=provider, messages=messages, model=self._get_base_model()
11 changes: 11 additions & 0 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -377,6 +377,8 @@ def prepare_input(
input_body["max_tokens"] = max_tokens
elif provider == "writer":
input_body["max_tokens"] = max_tokens
elif provider == "openai":
input_body["max_output_tokens"] = max_tokens
else:
# TODO: Add AI21 support, param depends on specific model.
pass
@@ -391,6 +393,13 @@
input_body["textGenerationConfig"]["maxTokenCount"] = max_tokens
if temperature is not None:
input_body["textGenerationConfig"]["temperature"] = temperature

elif provider == "openai":
input_body["messages"] = messages
if max_tokens:
input_body["max_tokens"] = max_tokens
if temperature is not None:
input_body["temperature"] = temperature
else:
input_body["inputText"] = prompt

@@ -442,6 +451,8 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
text = response_body.get("generation")
elif provider == "mistral":
text = response_body.get("outputs")[0].get("text")
elif provider == "openai":
text = response_body.get("choices")[0].get("message").get("content")
else:
text = response_body.get("results")[0].get("outputText")
