Skip to content

Commit 320200f

Browse files
Fix extended thinking block round-trip for LiteLLM assistant messages (#85)
Fixes #71. Related downstream issue: langchain-ai/deepagents#661.
1 parent ac75516 commit 320200f

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

langchain_litellm/chat_models/litellm.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,27 @@ def _create_retry_decorator(
102102
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
103103
)
104104

105+
def _inject_reasoning_content_into_content(
106+
content: Any, reasoning_content: str
107+
) -> List[Dict[str, Any]]:
108+
thinking_block = {"type": "thinking", "thinking": reasoning_content}
109+
if isinstance(content, list):
110+
has_thinking_block = any(
111+
isinstance(block, dict)
112+
and block.get("type") in ("thinking", "redacted_thinking")
113+
for block in content
114+
)
115+
if has_thinking_block:
116+
return content
117+
return [thinking_block, *content]
118+
119+
if not content:
120+
return [thinking_block]
121+
122+
if isinstance(content, str):
123+
return [thinking_block, {"type": "text", "text": content}]
124+
125+
return [thinking_block, content]
105126

106127
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
107128
role = _dict["role"]
@@ -152,6 +173,9 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
152173

153174
if _dict.get("reasoning_content"):
154175
additional_kwargs["reasoning_content"] = _dict["reasoning_content"]
176+
content = _inject_reasoning_content_into_content(
177+
content, _dict["reasoning_content"]
178+
)
155179

156180
# Check standard field first, then fallback to Vertex specific field
157181
provider_specific_fields = _dict.get("provider_specific_fields")
@@ -205,6 +229,9 @@ def _convert_delta_to_message_chunk(
205229
additional_kwargs["function_call"] = dict(function_call)
206230
if reasoning_content:
207231
additional_kwargs["reasoning_content"] = reasoning_content
232+
233+
if reasoning_content and (role == "assistant" or default_class == AIMessageChunk):
234+
content = _inject_reasoning_content_into_content(content, reasoning_content)
208235

209236
if provider_specific_fields is not None:
210237
additional_kwargs["provider_specific_fields"] = provider_specific_fields

tests/unit_tests/test_litellm.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
from litellm.types.utils import ChatCompletionDeltaToolCall, Delta, Function
88

99
from langchain_litellm.chat_models import ChatLiteLLM
10-
from langchain_litellm.chat_models.litellm import _convert_delta_to_message_chunk, _convert_dict_to_message
10+
from langchain_litellm.chat_models.litellm import (
11+
_convert_delta_to_message_chunk,
12+
_convert_dict_to_message,
13+
_inject_reasoning_content_into_content,
14+
)
1115

1216

1317
class TestChatLiteLLMUnit(ChatModelUnitTests):
@@ -154,4 +158,41 @@ def test_provider_specific_fields_in_chat_result(self):
154158
result = llm._create_chat_result(mock_response)
155159

156160
assert "provider_specific_fields" in result.llm_output
157-
assert result.llm_output["provider_specific_fields"]["citations"][0]["source"] == "test"
161+
assert result.llm_output["provider_specific_fields"]["citations"][0]["source"] == "test"
162+
163+
164+
def test_inject_reasoning_content_into_string_content() -> None:
    """String content is expanded to a thinking block followed by a text block."""
    injected = _inject_reasoning_content_into_content("answer", "hidden chain")

    expected = [
        {"type": "thinking", "thinking": "hidden chain"},
        {"type": "text", "text": "answer"},
    ]
    assert injected == expected
171+
172+
173+
def test_inject_reasoning_content_into_empty_content() -> None:
    """Empty string content yields a lone thinking block."""
    injected = _inject_reasoning_content_into_content("", "hidden chain")

    assert injected == [{"type": "thinking", "thinking": "hidden chain"}]
177+
178+
179+
def test_inject_reasoning_content_prepends_for_list_without_thinking() -> None:
    """A block list with no thinking block gets one prepended."""
    original = [{"type": "text", "text": "answer"}]

    injected = _inject_reasoning_content_into_content(original, "hidden chain")

    expected = [
        {"type": "thinking", "thinking": "hidden chain"},
        {"type": "text", "text": "answer"},
    ]
    assert injected == expected
188+
189+
190+
def test_inject_reasoning_content_does_not_duplicate_existing_thinking() -> None:
    """A list that already contains a thinking block is returned unchanged."""
    original = [
        {"type": "thinking", "thinking": "already there"},
        {"type": "text", "text": "answer"},
    ]

    injected = _inject_reasoning_content_into_content(original, "hidden chain")

    assert injected == original

0 commit comments

Comments
 (0)