diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 35beb08ec5f75..cb1f08f269096 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -123,6 +123,10 @@
     _convert_from_v1_to_responses,
     _convert_to_v03_ai_message,
 )
+from langchain_openai.chat_models.reasoning_parser import (
+    extract_reasoning_content,
+    extract_reasoning_delta,
+)
 
 if TYPE_CHECKING:
     from openai.types.responses import Response
@@ -1035,6 +1039,14 @@ def _convert_chunk_to_generation_chunk(
             message_chunk.usage_metadata = usage_metadata
 
         message_chunk.response_metadata["model_provider"] = "openai"
+
+        # Inject streaming reasoning delta
+        if choices := chunk.get("choices"):
+            delta = choices[0].get("delta", {})
+            reasoning_text = extract_reasoning_delta(self.model_name, delta)
+            if reasoning_text and isinstance(message_chunk, AIMessageChunk):
+                message_chunk.additional_kwargs["reasoning_content"] = reasoning_text
+
         return ChatGenerationChunk(
             message=message_chunk, generation_info=generation_info or None
         )
@@ -1416,6 +1428,12 @@ def _create_chat_result(
         if hasattr(message, "refusal"):
            generations[0].message.additional_kwargs["refusal"] = message.refusal
 
+        # Inject model-specific reasoning content
+        reasoning_text = extract_reasoning_content(self.model_name, response_dict)
+        if reasoning_text:
+            generations[0].message.additional_kwargs["reasoning_content"] = (
+                reasoning_text
+            )
         return ChatResult(generations=generations, llm_output=llm_output)
 
     async def _astream(
diff --git a/libs/partners/openai/langchain_openai/chat_models/reasoning_parser.py b/libs/partners/openai/langchain_openai/chat_models/reasoning_parser.py
new file mode 100644
index 0000000000000..b1309b683bd99
--- /dev/null
+++ b/libs/partners/openai/langchain_openai/chat_models/reasoning_parser.py
@@ -0,0 +1,54 @@
+# SPDX-License-Identifier: MIT
+"""Utilities for parsing non-standard reasoning or "thinking" fields.
+
+Some OpenAI-compatible servers (e.g., Qwen models) return reasoning traces
+in extra fields of the chat completion message objects.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+def extract_reasoning_content(
+    model_name: str, response_dict: dict[str, Any]
+) -> str | None:
+    """Extract the reasoning field from an OpenAI-compatible response.
+
+    This function handles Qwen-family models that provide internal reasoning or
+    "think" traces in their message objects.
+    """
+    if not isinstance(response_dict, dict):
+        return None
+
+    choices = response_dict.get("choices")
+    if not choices or not isinstance(choices, list):
+        return None
+
+    msg = choices[0].get("message")
+    if not isinstance(msg, dict):
+        return None
+
+    if "qwen" in (model_name or "").lower():
+        if "reasoning_content" in msg:
+            return msg["reasoning_content"]
+        for alt_key in ("think", "thought", "reasoning"):
+            if alt_key in msg:
+                return msg[alt_key]
+    return None
+
+
+def extract_reasoning_delta(model_name: str, delta_dict: dict[str, Any]) -> str | None:
+    """Extract the reasoning field from a streaming delta for Qwen models.
+
+    Used when consuming streamed chunks (`choices[0].delta`)
+    to accumulate partial reasoning text.
+ """ + if not isinstance(delta_dict, dict): + return None + + if "qwen" in (model_name or "").lower(): + for key in ("reasoning_content", "think", "thought"): + if key in delta_dict: + return delta_dict[key] + return None diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_reasoning_parser.py b/libs/partners/openai/tests/unit_tests/chat_models/test_reasoning_parser.py new file mode 100644 index 0000000000000..10213e222ce7b --- /dev/null +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_reasoning_parser.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: MIT +"""Unit tests for langchain_openai.chat_models.reasoning_parser.""" + +import pytest + +from langchain_openai.chat_models.reasoning_parser import ( + extract_reasoning_content, + extract_reasoning_delta, +) + + +@pytest.mark.parametrize( + ("model_name", "response_dict", "expected"), + [ + # Standard Qwen with reasoning_content + ( + "qwen3-chat", + { + "choices": [ + {"message": {"content": "hi", "reasoning_content": "I am thinking"}} + ] + }, + "I am thinking", + ), + # Qwen with alternative field names + ( + "qwen2.5-instruct", + {"choices": [{"message": {"content": "hi", "think": "Another thought"}}]}, + "Another thought", + ), + ( + "qwen-1.8", + { + "choices": [ + {"message": {"content": "hi", "thought": "Internal reasoning"}} + ] + }, + "Internal reasoning", + ), + # Non-Qwen model → should not extract anything + ( + "gpt-4-turbo", + { + "choices": [ + {"message": {"content": "hi", "reasoning_content": "ignore me"}} + ] + }, + None, + ), + # Invalid structure: no choices + ("qwen3-chat", {"message": {"content": "hi"}}, None), + # Invalid structure: message is not dict + ("qwen3-chat", {"choices": [{"message": "not a dict"}]}, None), + # Empty / malformed response + ("qwen3-chat", {}, None), + ], +) +def test_extract_reasoning_content( + model_name: str, response_dict: dict, expected: str | None +) -> None: + """Ensure reasoning extraction works correctly for various inputs.""" + result = extract_reasoning_content(model_name, response_dict) + assert result == expected + + +@pytest.mark.parametrize( + ("model_name", "delta_dict", "expected"), + [ + # Qwen stream delta with reasoning_content + ( + "qwen3-chat", + {"reasoning_content": "Streaming reasoning"}, + "Streaming reasoning", + ), + # Alternative field key + ("qwen3-chat", {"think": "Stream thinking..."}, "Stream thinking..."), + # Unsupported model → None + ("gpt-4o", {"reasoning_content": "should ignore"}, None), + # Malformed inputs + ("qwen3-chat", {}, None), + ("qwen3-chat", None, None), + ], +) +def test_extract_reasoning_delta( + model_name: str, delta_dict: dict | None, expected: str | None +) -> None: + """Ensure streaming delta reasoning extraction functions robustly.""" + result = extract_reasoning_delta(model_name, delta_dict or {}) + assert result == expected