Commit 88d5f3e

openai[patch]: allow specification of output format for Responses API (#31686)
1 parent 59c2b81 commit 88d5f3e

File tree

14 files changed: +328 −35 lines
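In user-facing terms, the commit adds an `output_version` field to ChatOpenAI that controls how Responses API output is shaped into AIMessages. A minimal usage sketch (model name and prompt are placeholders, not taken from this commit):

from langchain_openai import ChatOpenAI

# "v0" (the default) keeps the langchain-openai 0.3.x message shape;
# "responses/v1" renders Responses API output items as content blocks.
llm = ChatOpenAI(model="gpt-4o-mini", output_version="responses/v1")

response = llm.invoke("Hello")
# Reasoning items, built-in tool calls, and text now arrive as typed
# blocks in response.content rather than in additional_kwargs.
print(response.content)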

libs/langchain/tests/unit_tests/chat_models/test_base.py

Lines changed: 1 addition & 0 deletions

@@ -113,6 +113,7 @@ def test_configurable() -> None:
         "openai_api_base": None,
         "openai_organization": None,
         "openai_proxy": None,
+        "output_version": "v0",
         "request_timeout": None,
         "max_retries": None,
         "presence_penalty": None,

libs/partners/openai/langchain_openai/chat_models/_compat.py

Lines changed: 23 additions & 5 deletions

@@ -128,6 +128,8 @@ def _convert_to_v03_ai_message(
         else:
             new_content.append(block)
         message.content = new_content
+        if isinstance(message.id, str) and message.id.startswith("resp_"):
+            message.id = None
     else:
         pass

@@ -137,13 +139,29 @@ def _convert_to_v03_ai_message(
 def _convert_from_v03_ai_message(message: AIMessage) -> AIMessage:
     """Convert an old-style v0.3 AIMessage into the new content-block format."""
     # Only update ChatOpenAI v0.3 AIMessages
-    if not (
+    # TODO: structure provenance into AIMessage
+    is_chatopenai_v03 = (
         isinstance(message.content, list)
         and all(isinstance(b, dict) for b in message.content)
-    ) or not any(
-        item in message.additional_kwargs
-        for item in ["reasoning", "tool_outputs", "refusal", _FUNCTION_CALL_IDS_MAP_KEY]
-    ):
+    ) and (
+        any(
+            item in message.additional_kwargs
+            for item in [
+                "reasoning",
+                "tool_outputs",
+                "refusal",
+                _FUNCTION_CALL_IDS_MAP_KEY,
+            ]
+        )
+        or (
+            isinstance(message.id, str)
+            and message.id.startswith("msg_")
+            and (response_id := message.response_metadata.get("id"))
+            and isinstance(response_id, str)
+            and response_id.startswith("resp_")
+        )
+    )
+    if not is_chatopenai_v03:
         return message
 
     content_order = [
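To illustrate the widened heuristic: a message now counts as ChatOpenAI v0.3 when its content is a list of dicts and it either carries one of the known additional_kwargs keys or pairs a "msg_"-prefixed message id with a "resp_"-prefixed response id. A sketch with hypothetical ids:

from langchain_core.messages import AIMessage

message = AIMessage(
    content=[{"type": "text", "text": "Hello!"}],
    id="msg_123",  # hypothetical message-item id
    response_metadata={"id": "resp_456"},  # hypothetical response id
)
# is_chatopenai_v03 evaluates to True for this message even though none of
# "reasoning", "tool_outputs", "refusal", or the function-call map is set.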

libs/partners/openai/langchain_openai/chat_models/base.py

Lines changed: 53 additions & 9 deletions

@@ -649,6 +649,25 @@ class BaseChatOpenAI(BaseChatModel):
     .. versionadded:: 0.3.9
     """
 
+    output_version: Literal["v0", "responses/v1"] = "v0"
+    """Version of AIMessage output format to use.
+
+    This field is used to roll out new output formats for chat model AIMessages
+    in a backwards-compatible way.
+
+    Supported values:
+
+    - ``"v0"``: AIMessage format as of langchain-openai 0.3.x.
+    - ``"responses/v1"``: Formats Responses API output
+      items into AIMessage content blocks.
+
+    Currently only impacts the Responses API. ``output_version="responses/v1"`` is
+    recommended.
+
+    .. versionadded:: 0.3.25
+
+    """
+
     model_config = ConfigDict(populate_by_name=True)
 
     @model_validator(mode="before")
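The practical difference between the two values, as exercised by the integration tests further down (model name and assertions mirror those tests; treat this as a sketch):

from langchain_openai import ChatOpenAI

# v0 (default): reasoning output lands in additional_kwargs.
llm_v0 = ChatOpenAI(model="o4-mini", reasoning={"effort": "low"})
msg_v0 = llm_v0.invoke("Hello")
assert msg_v0.additional_kwargs["reasoning"]

# responses/v1: the same output arrives as typed content blocks.
llm_v1 = ChatOpenAI(
    model="o4-mini", reasoning={"effort": "low"}, output_version="responses/v1"
)
msg_v1 = llm_v1.invoke("Hello")
assert [block["type"] for block in msg_v1.content] == ["reasoning", "text"]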
@@ -903,6 +922,7 @@ def _stream_responses(
                 schema=original_schema_obj,
                 metadata=metadata,
                 has_reasoning=has_reasoning,
+                output_version=self.output_version,
             )
             if generation_chunk:
                 if run_manager:
@@ -957,6 +977,7 @@ async def _astream_responses(
                 schema=original_schema_obj,
                 metadata=metadata,
                 has_reasoning=has_reasoning,
+                output_version=self.output_version,
             )
             if generation_chunk:
                 if run_manager:
@@ -1096,7 +1117,10 @@ def _generate(
             else:
                 response = self.root_client.responses.create(**payload)
             return _construct_lc_result_from_responses_api(
-                response, schema=original_schema_obj, metadata=generation_info
+                response,
+                schema=original_schema_obj,
+                metadata=generation_info,
+                output_version=self.output_version,
             )
         elif self.include_response_headers:
             raw_response = self.client.with_raw_response.create(**payload)
@@ -1109,6 +1133,8 @@
     def _use_responses_api(self, payload: dict) -> bool:
         if isinstance(self.use_responses_api, bool):
             return self.use_responses_api
+        elif self.output_version == "responses/v1":
+            return True
         elif self.include is not None:
             return True
         elif self.reasoning is not None:
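A consequence of the new branch: setting output_version="responses/v1" is by itself enough to route requests to the Responses API, with no need to also pass use_responses_api=True. A sketch (model name is a placeholder):

from langchain_openai import ChatOpenAI

# Routed to the Responses API purely because of output_version.
llm = ChatOpenAI(model="gpt-4o-mini", output_version="responses/v1")

# Also routed to the Responses API, but output stays in the "v0" shape.
llm_v0 = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)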
@@ -1327,7 +1353,10 @@ async def _agenerate(
             else:
                 response = await self.root_async_client.responses.create(**payload)
             return _construct_lc_result_from_responses_api(
-                response, schema=original_schema_obj, metadata=generation_info
+                response,
+                schema=original_schema_obj,
+                metadata=generation_info,
+                output_version=self.output_version,
             )
         elif self.include_response_headers:
             raw_response = await self.async_client.with_raw_response.create(**payload)
@@ -3540,6 +3569,7 @@ def _construct_lc_result_from_responses_api(
     response: Response,
     schema: Optional[type[_BM]] = None,
     metadata: Optional[dict] = None,
+    output_version: Literal["v0", "responses/v1"] = "v0",
 ) -> ChatResult:
     """Construct ChatResponse from OpenAI Response API response."""
     if response.error:
@@ -3676,7 +3706,10 @@ def _construct_lc_result_from_responses_api(
         tool_calls=tool_calls,
         invalid_tool_calls=invalid_tool_calls,
     )
-    message = _convert_to_v03_ai_message(message)
+    if output_version == "v0":
+        message = _convert_to_v03_ai_message(message)
+    else:
+        pass
     return ChatResult(generations=[ChatGeneration(message=message)])
 
 
@@ -3688,6 +3721,7 @@ def _convert_responses_chunk_to_generation_chunk(
     schema: Optional[type[_BM]] = None,
     metadata: Optional[dict] = None,
     has_reasoning: bool = False,
+    output_version: Literal["v0", "responses/v1"] = "v0",
 ) -> tuple[int, int, int, Optional[ChatGenerationChunk]]:
     def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None:
         """Advance indexes tracked during streaming.
@@ -3756,12 +3790,15 @@ def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None:
     elif chunk.type == "response.output_text.done":
         content.append({"id": chunk.item_id, "index": current_index})
     elif chunk.type == "response.created":
-        response_metadata["id"] = chunk.response.id
+        id = chunk.response.id
+        response_metadata["id"] = chunk.response.id  # Backwards compatibility
     elif chunk.type == "response.completed":
         msg = cast(
             AIMessage,
             (
-                _construct_lc_result_from_responses_api(chunk.response, schema=schema)
+                _construct_lc_result_from_responses_api(
+                    chunk.response, schema=schema, output_version=output_version
+                )
                 .generations[0]
                 .message
             ),
@@ -3773,7 +3810,10 @@ def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None:
             k: v for k, v in msg.response_metadata.items() if k != "id"
         }
     elif chunk.type == "response.output_item.added" and chunk.item.type == "message":
-        id = chunk.item.id
+        if output_version == "v0":
+            id = chunk.item.id
+        else:
+            pass
     elif (
         chunk.type == "response.output_item.added"
         and chunk.item.type == "function_call"
@@ -3868,9 +3908,13 @@ def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None:
         additional_kwargs=additional_kwargs,
         id=id,
     )
-    message = cast(
-        AIMessageChunk, _convert_to_v03_ai_message(message, has_reasoning=has_reasoning)
-    )
+    if output_version == "v0":
+        message = cast(
+            AIMessageChunk,
+            _convert_to_v03_ai_message(message, has_reasoning=has_reasoning),
+        )
+    else:
+        pass
     return (
         current_index,
         current_output_index,
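Read together with the _compat.py change above, the id handling works out roughly as follows: under "v0" the chunk id taken from chunk.response.id is later overwritten with the message item id ("msg_…"), and a leftover "resp_…" id is nulled out during conversion; under "responses/v1" the response id assigned on response.created is kept. A sketch of observing this while streaming (model name is a placeholder):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", output_version="responses/v1")

full = None
for chunk in llm.stream("Hello"):
    full = chunk if full is None else full + chunk

# The accumulated message keeps the Responses API response id, and
# response_metadata["id"] is still populated for backwards compatibility.
assert full is not None and full.id and full.id.startswith("resp_")
assert full.response_metadata["id"] == full.id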
4 binary files changed (contents not shown)

libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py

Lines changed: 53 additions & 17 deletions

@@ -2,7 +2,7 @@
 
 import json
 import os
-from typing import Annotated, Any, Optional, cast
+from typing import Annotated, Any, Literal, Optional, cast
 
 import openai
 import pytest
@@ -50,15 +50,11 @@ def _check_response(response: Optional[BaseMessage]) -> None:
     assert response.usage_metadata["total_tokens"] > 0
     assert response.response_metadata["model_name"]
     assert response.response_metadata["service_tier"]
-    for tool_output in response.additional_kwargs["tool_outputs"]:
-        assert tool_output["id"]
-        assert tool_output["status"]
-        assert tool_output["type"]
 
 
 @pytest.mark.vcr
 def test_web_search() -> None:
-    llm = ChatOpenAI(model=MODEL_NAME)
+    llm = ChatOpenAI(model=MODEL_NAME, output_version="responses/v1")
     first_response = llm.invoke(
         "What was a positive news story from today?",
         tools=[{"type": "web_search_preview"}],
@@ -111,6 +107,11 @@ def test_web_search() -> None:
     )
     _check_response(response)
 
+    for msg in [first_response, full, response]:
+        assert isinstance(msg, AIMessage)
+        block_types = [block["type"] for block in msg.content]  # type: ignore[index]
+        assert block_types == ["web_search_call", "text"]
+
 
 @pytest.mark.flaky(retries=3, delay=1)
 async def test_web_search_async() -> None:
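The assertions added above pin down the responses/v1 shape for built-in tools: the web search call appears as a typed "web_search_call" block preceding the "text" block, instead of living only in additional_kwargs["tool_outputs"]. A small helper in the same spirit (a sketch, not part of the commit):

from langchain_core.messages import AIMessage

def block_types(msg: AIMessage) -> list[str]:
    """List the content-block types of a responses/v1-style AIMessage."""
    return [block["type"] for block in msg.content if isinstance(block, dict)]

# For the web-search responses above this returns ["web_search_call", "text"].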
@@ -133,6 +134,12 @@ async def test_web_search_async() -> None:
     assert isinstance(full, AIMessageChunk)
     _check_response(full)
 
+    for msg in [response, full]:
+        assert msg.additional_kwargs["tool_outputs"]
+        assert len(msg.additional_kwargs["tool_outputs"]) == 1
+        tool_output = msg.additional_kwargs["tool_outputs"][0]
+        assert tool_output["type"] == "web_search_call"
+
 
 @pytest.mark.flaky(retries=3, delay=1)
 def test_function_calling() -> None:
@@ -288,20 +295,32 @@ def multiply(x: int, y: int) -> int:
     assert set(ai_msg.tool_calls[0]["args"]) == {"x", "y"}
 
 
-def test_reasoning() -> None:
-    llm = ChatOpenAI(model="o3-mini", use_responses_api=True)
+@pytest.mark.default_cassette("test_reasoning.yaml.gz")
+@pytest.mark.vcr
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
+def test_reasoning(output_version: Literal["v0", "responses/v1"]) -> None:
+    llm = ChatOpenAI(
+        model="o4-mini", use_responses_api=True, output_version=output_version
+    )
     response = llm.invoke("Hello", reasoning={"effort": "low"})
     assert isinstance(response, AIMessage)
-    assert response.additional_kwargs["reasoning"]
 
     # Test init params + streaming
-    llm = ChatOpenAI(model="o3-mini", reasoning_effort="low", use_responses_api=True)
+    llm = ChatOpenAI(
+        model="o4-mini", reasoning={"effort": "low"}, output_version=output_version
+    )
     full: Optional[BaseMessageChunk] = None
     for chunk in llm.stream("Hello"):
        assert isinstance(chunk, AIMessageChunk)
        full = chunk if full is None else full + chunk
     assert isinstance(full, AIMessage)
-    assert full.additional_kwargs["reasoning"]
+
+    for msg in [response, full]:
+        if output_version == "v0":
+            assert msg.additional_kwargs["reasoning"]
+        else:
+            block_types = [block["type"] for block in msg.content]
+            assert block_types == ["reasoning", "text"]
 
 
 def test_stateful_api() -> None:
@@ -355,20 +374,37 @@ def test_file_search() -> None:
     _check_response(full)
 
 
-def test_stream_reasoning_summary() -> None:
+@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
+@pytest.mark.vcr
+@pytest.mark.parametrize("output_version", ["v0", "responses/v1"])
+def test_stream_reasoning_summary(
+    output_version: Literal["v0", "responses/v1"],
+) -> None:
     llm = ChatOpenAI(
         model="o4-mini",
         # Routes to Responses API if `reasoning` is set.
         reasoning={"effort": "medium", "summary": "auto"},
+        output_version=output_version,
     )
-    message_1 = {"role": "user", "content": "What is 3^3?"}
+    message_1 = {
+        "role": "user",
+        "content": "What was the third tallest buliding in the year 2000?",
+    }
     response_1: Optional[BaseMessageChunk] = None
     for chunk in llm.stream([message_1]):
         assert isinstance(chunk, AIMessageChunk)
         response_1 = chunk if response_1 is None else response_1 + chunk
     assert isinstance(response_1, AIMessageChunk)
-    reasoning = response_1.additional_kwargs["reasoning"]
-    assert set(reasoning.keys()) == {"id", "type", "summary"}
+    if output_version == "v0":
+        reasoning = response_1.additional_kwargs["reasoning"]
+        assert set(reasoning.keys()) == {"id", "type", "summary"}
+    else:
+        reasoning = next(
+            block
+            for block in response_1.content
+            if block["type"] == "reasoning"  # type: ignore[index]
+        )
+        assert set(reasoning.keys()) == {"id", "type", "summary", "index"}
     summary = reasoning["summary"]
     assert isinstance(summary, list)
     for block in summary:
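Under "responses/v1", then, a streamed reasoning item surfaces as a content block with an extra "index" key (used when merging chunks) alongside "id", "type", and "summary". A helper to pull out the summary texts — a sketch assuming summary parts carry a "text" field, as in the test's continuation:

from langchain_core.messages import AIMessageChunk

def reasoning_summaries(msg: AIMessageChunk) -> list[str]:
    """Collect summary texts from responses/v1 reasoning blocks."""
    texts: list[str] = []
    for block in msg.content:
        if isinstance(block, dict) and block.get("type") == "reasoning":
            for part in block.get("summary", []):
                texts.append(part.get("text", ""))
    return texts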
@@ -462,11 +498,11 @@ def test_mcp_builtin() -> None:
     )
 
 
-@pytest.mark.skip
+@pytest.mark.vcr
 def test_mcp_builtin_zdr() -> None:
     llm = ChatOpenAI(
         model="o4-mini",
-        use_responses_api=True,
+        output_version="responses/v1",
         store=False,
         include=["reasoning.encrypted_content"],
     )

libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@
     }),
     'openai_api_type': 'azure',
     'openai_api_version': '2021-10-01',
+    'output_version': 'v0',
     'request_timeout': 60.0,
     'stop': list([
     ]),

libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@
       'lc': 1,
       'type': 'secret',
     }),
+    'output_version': 'v0',
     'request_timeout': 60.0,
     'stop': list([
     ]),
