Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@
_convert_from_v1_to_responses,
_convert_to_v03_ai_message,
)
from langchain_openai.chat_models.reasoning_parser import (
extract_reasoning_content,
extract_reasoning_delta,
)

if TYPE_CHECKING:
from openai.types.responses import Response
Expand Down Expand Up @@ -1035,6 +1039,14 @@ def _convert_chunk_to_generation_chunk(
message_chunk.usage_metadata = usage_metadata

message_chunk.response_metadata["model_provider"] = "openai"

# Inject streaming reasoning delta
if choices := chunk.get("choices"):
delta = choices[0].get("delta", {})
reasoning_text = extract_reasoning_delta(self.model_name, delta)
if reasoning_text and isinstance(message_chunk, AIMessageChunk):
message_chunk.additional_kwargs["reasoning_content"] = reasoning_text

return ChatGenerationChunk(
message=message_chunk, generation_info=generation_info or None
)
Expand Down Expand Up @@ -1416,6 +1428,12 @@ def _create_chat_result(
if hasattr(message, "refusal"):
generations[0].message.additional_kwargs["refusal"] = message.refusal

# Inject model-specific reasoning content
reasoning_text = extract_reasoning_content(self.model_name, response_dict)
if reasoning_text:
generations[0].message.additional_kwargs["reasoning_content"] = (
reasoning_text
)
return ChatResult(generations=generations, llm_output=llm_output)

async def _astream(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: MIT
"""Utility functions.

Parsing non-standard reasoning or thinking fields
from OpenAI-compatible chat completion responses (e.g., Qwen models).
"""

from __future__ import annotations

from typing import Any


def extract_reasoning_content(
    model_name: str, response_dict: dict[str, Any]
) -> str | None:
    """Extract 'reasoning_content' fields from an OpenAI-compatible response.

    This function handles Qwen-family models that provide internal reasoning or
    "think" traces in their message objects.

    Args:
        model_name: Name of the model that produced the response. Only models
            whose name contains ``"qwen"`` (case-insensitive) are inspected.
        response_dict: Raw chat-completion response as a dict.

    Returns:
        The reasoning text from the first choice's message if present,
        otherwise ``None``.
    """
    if not isinstance(response_dict, dict):
        return None

    choices = response_dict.get("choices")
    if not isinstance(choices, list) or not choices:
        return None

    first_choice = choices[0]
    # Guard against malformed entries (e.g. ``{"choices": [None]}``) before
    # calling ``.get`` -- previously this raised AttributeError.
    if not isinstance(first_choice, dict):
        return None

    msg = first_choice.get("message")
    if not isinstance(msg, dict):
        return None

    if "qwen" not in (model_name or "").lower():
        return None

    # Primary field first, then known alternative spellings.
    for key in ("reasoning_content", "think", "thought", "reasoning"):
        if key in msg:
            return msg[key]
    return None


def extract_reasoning_delta(model_name: str, delta_dict: dict[str, Any]) -> str | None:
    """Extract reasoning field from incremental streaming deltas for Qwen models.

    Used when consuming stream data (``choices[0].delta``)
    to combine partial reasoning text.

    Args:
        model_name: Name of the model producing the stream. Only models whose
            name contains ``"qwen"`` (case-insensitive) are inspected.
        delta_dict: The ``delta`` payload of the first streamed choice.

    Returns:
        The reasoning-text fragment if present, otherwise ``None``.
    """
    if not isinstance(delta_dict, dict):
        return None

    if "qwen" not in (model_name or "").lower():
        return None

    # Same key order as extract_reasoning_content; the "reasoning" alias was
    # previously recognized on the non-streaming path only.
    for key in ("reasoning_content", "think", "thought", "reasoning"):
        if key in delta_dict:
            return delta_dict[key]
    return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: MIT
"""Unit tests for langchain_openai.chat_models.reasoning_parser."""

import pytest

from langchain_openai.chat_models.reasoning_parser import (
extract_reasoning_content,
extract_reasoning_delta,
)


@pytest.mark.parametrize(
    ("model", "payload", "want"),
    [
        pytest.param(
            "qwen3-chat",
            {
                "choices": [
                    {"message": {"content": "hi", "reasoning_content": "I am thinking"}}
                ]
            },
            "I am thinking",
            id="qwen-reasoning_content",
        ),
        pytest.param(
            "qwen2.5-instruct",
            {"choices": [{"message": {"content": "hi", "think": "Another thought"}}]},
            "Another thought",
            id="qwen-think-alias",
        ),
        pytest.param(
            "qwen-1.8",
            {
                "choices": [
                    {"message": {"content": "hi", "thought": "Internal reasoning"}}
                ]
            },
            "Internal reasoning",
            id="qwen-thought-alias",
        ),
        pytest.param(
            "gpt-4-turbo",
            {
                "choices": [
                    {"message": {"content": "hi", "reasoning_content": "ignore me"}}
                ]
            },
            None,
            id="non-qwen-ignored",
        ),
        pytest.param("qwen3-chat", {"message": {"content": "hi"}}, None, id="no-choices"),
        pytest.param(
            "qwen3-chat",
            {"choices": [{"message": "not a dict"}]},
            None,
            id="message-not-dict",
        ),
        pytest.param("qwen3-chat", {}, None, id="empty-response"),
    ],
)
def test_extract_reasoning_content(
    model: str, payload: dict, want: str | None
) -> None:
    """Verify full-response reasoning extraction across valid and malformed inputs."""
    assert extract_reasoning_content(model, payload) == want


@pytest.mark.parametrize(
    ("model", "delta", "want"),
    [
        pytest.param(
            "qwen3-chat",
            {"reasoning_content": "Streaming reasoning"},
            "Streaming reasoning",
            id="qwen-reasoning_content",
        ),
        pytest.param(
            "qwen3-chat",
            {"think": "Stream thinking..."},
            "Stream thinking...",
            id="qwen-think-alias",
        ),
        pytest.param(
            "gpt-4o", {"reasoning_content": "should ignore"}, None, id="non-qwen-ignored"
        ),
        pytest.param("qwen3-chat", {}, None, id="empty-delta"),
        pytest.param("qwen3-chat", None, None, id="none-delta"),
    ],
)
def test_extract_reasoning_delta(
    model: str, delta: dict | None, want: str | None
) -> None:
    """Verify streaming-delta reasoning extraction handles partial and bad input."""
    assert extract_reasoning_delta(model, delta or {}) == want