diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index 914b1800..e68cfdcc 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -35,6 +35,7 @@ ) from langchain_core.messages.ai import UsageMetadata from langchain_core.messages.tool import ToolCall, ToolMessage +from langchain_core.messages.utils import convert_to_openai_messages from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool @@ -285,6 +286,50 @@ def convert_messages_to_prompt_writer(messages: List[BaseMessage]) -> str: ) +def _convert_one_message_to_text_openai(message: BaseMessage) -> str: + if isinstance(message, SystemMessage): + message_text = ( + f"<|start|>system<|message|>{message.content}<|end|>" + ) + elif isinstance(message, ChatMessage): + # developer role messages + message_text = ( + f"<|start|>{message.role}<|message|>{message.content}<|end|>" + ) + elif isinstance(message, HumanMessage): + message_text = ( + f"<|start|>user<|message|>{message.content}<|end|>" + ) + elif isinstance(message, AIMessage): + message_text = ( + f"<|start|>assistant<|channel|>final<|message|>{message.content}<|end|>" + ) + elif isinstance(message, ToolMessage): + # TODO: Tool messages in the OpenAI format should use "<|start|>{toolname} to=assistant<|message|>" + # Need to extract the tool name from the ToolMessage content or tool_call_id + # For now using generic "to=assistant" format as placeholder until we implement tool calling + # Will be resolved in follow-up PR with full tool support + message_text = ( + f"<|start|>to=assistant<|channel|>commentary<|message|>{message.content}<|end|>" + ) + else: + raise ValueError(f"Got unknown type {message}") + + return message_text + + +def convert_messages_to_prompt_openai(messages: List[BaseMessage]) -> str: + """Convert a list of messages to a Harmony format prompt for OpenAI Responses API.""" + + prompt = "\n" + for message in messages: + prompt += _convert_one_message_to_text_openai(message) + + prompt += "<|start|>assistant\n\n" + + return prompt + + def _format_image(image_url: str) -> Dict: """ Formats an image of format data:image/jpeg;base64,{b64_string} @@ -640,6 +685,8 @@ def convert_messages_to_prompt( ) elif provider == "writer": prompt = convert_messages_to_prompt_writer(messages=messages) + elif provider == "openai": + prompt = convert_messages_to_prompt_openai(messages=messages) else: raise NotImplementedError( f"Provider {provider} model does not support chat." @@ -649,10 +696,11 @@ def convert_messages_to_prompt( @classmethod def format_messages( cls, provider: str, messages: List[BaseMessage] - ) -> Tuple[Optional[str], List[Dict]]: + ) -> Union[Tuple[Optional[str], List[Dict]], List[Dict]]: if provider == "anthropic": return _format_anthropic_messages(messages) - + elif provider == "openai": + return convert_to_openai_messages(messages) raise NotImplementedError( f"Provider {provider} not supported for format_messages" ) @@ -777,6 +825,8 @@ def _stream( system = self.system_prompt_with_tools + f"\n{system}" else: system = self.system_prompt_with_tools + elif provider == "openai": + formatted_messages = ChatPromptAdapter.format_messages(provider, messages) else: prompt = ChatPromptAdapter.convert_messages_to_prompt( provider=provider, messages=messages, model=self._get_base_model() @@ -876,6 +926,8 @@ def _generate( system = self.system_prompt_with_tools + f"\n{system}" else: system = self.system_prompt_with_tools + elif provider == "openai": + formatted_messages = ChatPromptAdapter.format_messages(provider, messages) else: prompt = ChatPromptAdapter.convert_messages_to_prompt( provider=provider, messages=messages, model=self._get_base_model() diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 3204363d..0542a528 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -377,6 +377,8 @@ def prepare_input( input_body["max_tokens"] = max_tokens elif provider == "writer": input_body["max_tokens"] = max_tokens + elif provider == "openai": + input_body["max_output_tokens"] = max_tokens else: # TODO: Add AI21 support, param depends on specific model. pass @@ -391,6 +393,13 @@ def prepare_input( input_body["textGenerationConfig"]["maxTokenCount"] = max_tokens if temperature is not None: input_body["textGenerationConfig"]["temperature"] = temperature + + elif provider == "openai": + input_body["messages"] = messages + if max_tokens: + input_body["max_tokens"] = max_tokens + if temperature is not None: + input_body["temperature"] = temperature else: input_body["inputText"] = prompt @@ -442,6 +451,8 @@ def prepare_output(cls, provider: str, response: Any) -> dict: text = response_body.get("generation") elif provider == "mistral": text = response_body.get("outputs")[0].get("text") + elif provider == "openai": + text = response_body.get("choices")[0].get("message").get("content") else: text = response_body.get("results")[0].get("outputText")