
Commit 43b8ebd

michaelnchin and 3coins authored
Support OpenAI gpt-oss models on ChatBedrock (#581)
Adds base support for using [gpt-oss](https://openai.com/index/introducing-gpt-oss/) on ChatBedrock. Tool and streaming support will be implemented in later PRs.

Note that Bedrock currently uses the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat) for gpt-oss models, not the newer [Responses API](https://platform.openai.com/docs/api-reference/responses), so the input (messages only) and parameters must be formatted in the request body accordingly.

Co-authored-by: Piyush Jain <[email protected]>
1 parent 9ab4dbe commit 43b8ebd
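For context, a minimal usage sketch of what this commit enables. The model ID and region below are assumptions (verify the exact gpt-oss model IDs available in your region in the Bedrock console), and since tool calling and streaming are not yet wired up, only plain invoke is shown:

# Minimal sketch, assuming a Bedrock gpt-oss model ID such as
# "openai.gpt-oss-120b-1:0" (an assumption; check the Bedrock model
# catalog) and AWS credentials already configured locally.
from langchain_aws import ChatBedrock
from langchain_core.messages import HumanMessage, SystemMessage

llm = ChatBedrock(
    model_id="openai.gpt-oss-120b-1:0",  # assumed model ID
    region_name="us-west-2",  # assumed region
    model_kwargs={"temperature": 0.7, "max_tokens": 512},
)

response = llm.invoke(
    [
        SystemMessage(content="You are a concise assistant."),
        HumanMessage(content="What is Amazon Bedrock, in one sentence?"),
    ]
)
print(response.content)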

2 files changed: +65 −2

libs/aws/langchain_aws/chat_models/bedrock.py

Lines changed: 54 additions & 2 deletions
@@ -35,6 +35,7 @@
 )
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.tool import ToolCall, ToolMessage
+from langchain_core.messages.utils import convert_to_openai_messages
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
@@ -285,6 +286,50 @@ def convert_messages_to_prompt_writer(messages: List[BaseMessage]) -> str:
     )
 
 
+def _convert_one_message_to_text_openai(message: BaseMessage) -> str:
+    if isinstance(message, SystemMessage):
+        message_text = (
+            f"<|start|>system<|message|>{message.content}<|end|>"
+        )
+    elif isinstance(message, ChatMessage):
+        # developer role messages
+        message_text = (
+            f"<|start|>{message.role}<|message|>{message.content}<|end|>"
+        )
+    elif isinstance(message, HumanMessage):
+        message_text = (
+            f"<|start|>user<|message|>{message.content}<|end|>"
+        )
+    elif isinstance(message, AIMessage):
+        message_text = (
+            f"<|start|>assistant<|channel|>final<|message|>{message.content}<|end|>"
+        )
+    elif isinstance(message, ToolMessage):
+        # TODO: Tool messages in the OpenAI format should use
+        # "<|start|>{toolname} to=assistant<|message|>". We need to extract
+        # the tool name from the ToolMessage content or tool_call_id; for
+        # now, the generic "to=assistant" format is a placeholder until tool
+        # calling lands in a follow-up PR with full tool support.
+        message_text = (
+            f"<|start|>to=assistant<|channel|>commentary<|message|>{message.content}<|end|>"
+        )
+    else:
+        raise ValueError(f"Got unknown type {message}")
+
+    return message_text
+
+
+def convert_messages_to_prompt_openai(messages: List[BaseMessage]) -> str:
+    """Convert a list of messages to a Harmony-format prompt for OpenAI gpt-oss models."""
+
+    prompt = "\n"
+    for message in messages:
+        prompt += _convert_one_message_to_text_openai(message)
+
+    prompt += "<|start|>assistant\n\n"
+
+    return prompt
+
+
 def _format_image(image_url: str) -> Dict:
     """
     Formats an image of format data:image/jpeg;base64,{b64_string}
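To make the Harmony framing concrete, here is a sketch of what the new converter produces for a short conversation (the import path follows the file above; the output is shown with escaped newlines):

from langchain_core.messages import HumanMessage, SystemMessage

from langchain_aws.chat_models.bedrock import convert_messages_to_prompt_openai

prompt = convert_messages_to_prompt_openai(
    [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello!"),
    ]
)
# prompt is a single string; with newlines escaped it reads:
# "\n<|start|>system<|message|>You are a helpful assistant.<|end|>"
# "<|start|>user<|message|>Hello!<|end|><|start|>assistant\n\n"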
@@ -640,6 +685,8 @@ def convert_messages_to_prompt(
             )
         elif provider == "writer":
             prompt = convert_messages_to_prompt_writer(messages=messages)
+        elif provider == "openai":
+            prompt = convert_messages_to_prompt_openai(messages=messages)
         else:
             raise NotImplementedError(
                 f"Provider {provider} model does not support chat."
@@ -649,10 +696,11 @@ def convert_messages_to_prompt(
     @classmethod
     def format_messages(
         cls, provider: str, messages: List[BaseMessage]
-    ) -> Tuple[Optional[str], List[Dict]]:
+    ) -> Union[Tuple[Optional[str], List[Dict]], List[Dict]]:
         if provider == "anthropic":
             return _format_anthropic_messages(messages)
-
+        elif provider == "openai":
+            return convert_to_openai_messages(messages)
         raise NotImplementedError(
             f"Provider {provider} not supported for format_messages"
         )
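For the new "openai" branch, format_messages returns Chat Completions-style message dicts rather than the (system, messages) tuple, hence the widened return type. Roughly:

from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages.utils import convert_to_openai_messages

oai_messages = convert_to_openai_messages(
    [HumanMessage(content="Hi"), AIMessage(content="Hello! How can I help?")]
)
# [{"role": "user", "content": "Hi"},
#  {"role": "assistant", "content": "Hello! How can I help?"}]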
@@ -777,6 +825,8 @@ def _stream(
                     system = self.system_prompt_with_tools + f"\n{system}"
                 else:
                     system = self.system_prompt_with_tools
+        elif provider == "openai":
+            formatted_messages = ChatPromptAdapter.format_messages(provider, messages)
         else:
             prompt = ChatPromptAdapter.convert_messages_to_prompt(
                 provider=provider, messages=messages, model=self._get_base_model()
@@ -876,6 +926,8 @@ def _generate(
                     system = self.system_prompt_with_tools + f"\n{system}"
                 else:
                     system = self.system_prompt_with_tools
+        elif provider == "openai":
+            formatted_messages = ChatPromptAdapter.format_messages(provider, messages)
         else:
             prompt = ChatPromptAdapter.convert_messages_to_prompt(
                 provider=provider, messages=messages, model=self._get_base_model()

libs/aws/langchain_aws/llms/bedrock.py

Lines changed: 11 additions & 0 deletions
@@ -377,6 +377,8 @@ def prepare_input(
             input_body["max_tokens"] = max_tokens
         elif provider == "writer":
             input_body["max_tokens"] = max_tokens
+        elif provider == "openai":
+            input_body["max_output_tokens"] = max_tokens
         else:
             # TODO: Add AI21 support, param depends on specific model.
             pass
@@ -391,6 +393,13 @@ def prepare_input(
                 input_body["textGenerationConfig"]["maxTokenCount"] = max_tokens
             if temperature is not None:
                 input_body["textGenerationConfig"]["temperature"] = temperature
+
+        elif provider == "openai":
+            input_body["messages"] = messages
+            if max_tokens:
+                input_body["max_tokens"] = max_tokens
+            if temperature is not None:
+                input_body["temperature"] = temperature
         else:
             input_body["inputText"] = prompt
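Taken together, the request body for an "openai"-provider model follows the Chat Completions shape. A sketch of what InvokeModel would receive, mirroring the prepare_input branch above (values are illustrative only):

import json

# Illustrative body: "messages" comes from convert_to_openai_messages,
# the remaining parameters from model kwargs.
body = json.dumps(
    {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        "max_tokens": 512,
        "temperature": 0.7,
    }
)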

@@ -442,6 +451,8 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
             text = response_body.get("generation")
         elif provider == "mistral":
             text = response_body.get("outputs")[0].get("text")
+        elif provider == "openai":
+            text = response_body.get("choices")[0].get("message").get("content")
         else:
             text = response_body.get("results")[0].get("outputText")
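And on the way back, prepare_output reads the completion text from a Chat Completions-style response. A sketch with an illustrative payload:

import json

# Illustrative response body in the Chat Completions shape that the
# "openai" branch above expects.
response_body = json.loads(
    '{"choices": [{"message": {"role": "assistant", "content": "Hi there!"}}]}'
)
text = response_body.get("choices")[0].get("message").get("content")
assert text == "Hi there!"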
