Commit dcf5c7b

groq: add support for accessing reasoning output from Groq models (#31662)
**Description:** Return [reasoning](https://console.groq.com/docs/reasoning) output in `additional_kwargs` as `reasoning_content`.

**Issue:** Resolves #31052
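A minimal usage sketch of the new behavior (the model name is taken from the integration tests added in this commit; assumes `GROQ_API_KEY` is set in the environment):

```python
from langchain_groq import ChatGroq

llm = ChatGroq(
    model="deepseek-r1-distill-llama-70b",
    reasoning_format="parsed",  # new parameter introduced by this commit
)
response = llm.invoke("Translate 'I love programming.' to French.")

# The model's reasoning is now surfaced alongside the answer.
print(response.additional_kwargs["reasoning_content"])
print(response.content)
```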
1 parent af2188b commit dcf5c7b

File tree

2 files changed: +72 -2 lines changed


libs/partners/groq/langchain_groq/chat_models.py

Lines changed: 19 additions & 1 deletion
@@ -107,6 +107,12 @@ class ChatGroq(BaseChatModel):
             Sampling temperature. Ranges from 0.0 to 1.0.
         max_tokens: Optional[int]
             Max number of tokens to generate.
+        reasoning_format: Optional[Literal["parsed", "raw", "hidden"]]
+            The format for reasoning output.
+
+            - ``parsed``: Separates reasoning into a dedicated field while keeping the response concise.
+            - ``raw``: Includes reasoning within think tags in the content.
+            - ``hidden``: Returns only the final answer.
         model_kwargs: Dict[str, Any]
             Holds any model parameters valid for create call not
             explicitly specified.
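The three modes documented above differ in where (or whether) the reasoning appears. A hedged sketch of what each is expected to return; the `<think>` tag shape for ``raw`` follows Groq's reasoning docs and is an assumption here, not verified output:

```python
from langchain_groq import ChatGroq


def ask(fmt: str):
    # Helper for this sketch only; assumes GROQ_API_KEY is set.
    llm = ChatGroq(model="deepseek-r1-distill-llama-70b", reasoning_format=fmt)
    return llm.invoke("What is 17 * 24?")


parsed = ask("parsed")
print(parsed.additional_kwargs.get("reasoning_content"))  # reasoning in its own field

raw = ask("raw")
print("<think>" in raw.content)  # reasoning inlined in the content, in think tags

hidden = ask("hidden")
print(hidden.additional_kwargs.get("reasoning_content"))  # None: final answer only
```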
@@ -292,7 +298,7 @@ class Joke(BaseModel):
         'system_fingerprint': 'fp_c5f20b5bb1',
         'finish_reason': 'stop',
         'logprobs': None}
-    """
+    """  # noqa: E501

     client: Any = Field(default=None, exclude=True)  #: :meta private:
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
@@ -302,6 +308,13 @@ class Joke(BaseModel):
     """What sampling temperature to use."""
     stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
+    reasoning_format: Optional[Literal["parsed", "raw", "hidden"]] = None
+    """The format for reasoning output.
+
+    - ``parsed``: Separates reasoning into a dedicated field while keeping the response concise.
+    - ``raw``: Includes reasoning within think tags in the content.
+    - ``hidden``: Returns only the final answer.
+    """  # noqa: E501
     model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
     groq_api_key: Optional[SecretStr] = Field(
@@ -606,6 +619,7 @@ def _default_params(self) -> dict[str, Any]:
             "n": self.n,
             "temperature": self.temperature,
             "stop": self.stop,
+            "reasoning_format": self.reasoning_format,
             **self.model_kwargs,
         }
         if self.max_tokens is not None:
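The one-line change above is what actually forwards the new field to the API: `reasoning_format` rides along with the other default request parameters. A quick way to see it, poking the private `_default_params` property for illustration only (assumes `GROQ_API_KEY` is set, since client construction validates credentials):

```python
from langchain_groq import ChatGroq

llm = ChatGroq(model="deepseek-r1-distill-llama-70b", reasoning_format="parsed")
# _default_params is private API; inspected here only to show the pass-through.
assert llm._default_params["reasoning_format"] == "parsed"
```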
@@ -1153,6 +1167,8 @@ def _convert_chunk_to_message_chunk(
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
     elif role == "assistant" or default_class == AIMessageChunk:
+        if reasoning := _dict.get("reasoning"):
+            additional_kwargs["reasoning_content"] = reasoning
         if usage := (chunk.get("x_groq") or {}).get("usage"):
             input_tokens = usage.get("prompt_tokens", 0)
             output_tokens = usage.get("completion_tokens", 0)
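In the streaming path, each delta that carries a `reasoning` field becomes an `AIMessageChunk` with `reasoning_content` in `additional_kwargs`. Because chunk addition in `langchain_core` concatenates string values in `additional_kwargs`, the reasoning accumulates as chunks are summed, which is what the new streaming test relies on. A self-contained sketch (the chunk contents are made up):

```python
from langchain_core.messages import AIMessageChunk

# Two chunks shaped the way _convert_chunk_to_message_chunk now produces them.
a = AIMessageChunk(content="", additional_kwargs={"reasoning_content": "First, recall "})
b = AIMessageChunk(content="Bonjour", additional_kwargs={"reasoning_content": "the greeting."})

merged = a + b
print(merged.additional_kwargs["reasoning_content"])  # "First, recall the greeting."
print(merged.content)                                 # "Bonjour"
```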
@@ -1196,6 +1212,8 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     elif role == "assistant":
         content = _dict.get("content", "") or ""
         additional_kwargs: dict = {}
+        if reasoning := _dict.get("reasoning"):
+            additional_kwargs["reasoning_content"] = reasoning
         if function_call := _dict.get("function_call"):
             additional_kwargs["function_call"] = dict(function_call)
         tool_calls = []
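The non-streaming path mirrors this: when the completed assistant message dict contains `reasoning`, it is copied into `additional_kwargs` before the `AIMessage` is built. A simplified re-implementation for illustration (not the library's exact code; the payload shape is an assumption based on the diff above):

```python
from langchain_core.messages import AIMessage

raw = {
    "role": "assistant",
    "content": "J'adore la programmation.",
    "reasoning": "The user wants the French translation of 'I love programming.'",
}

additional_kwargs: dict = {}
if reasoning := raw.get("reasoning"):
    additional_kwargs["reasoning_content"] = reasoning

msg = AIMessage(content=raw.get("content", "") or "", additional_kwargs=additional_kwargs)
print(msg.additional_kwargs["reasoning_content"])
```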

libs/partners/groq/tests/integration_tests/test_chat_models.py

Lines changed: 53 additions & 1 deletion
@@ -1,7 +1,7 @@
 """Test ChatGroq chat model."""

 import json
-from typing import Any, Optional
+from typing import Any, Optional, cast

 import pytest
 from langchain_core.messages import (
@@ -212,6 +212,58 @@ async def test_agenerate_streaming() -> None:
     assert generation.text == generation.message.content


+#
+# Test reasoning output
+#
+def test_reasoning_output_invoke() -> None:
+    """Test reasoning output from ChatGroq with invoke."""
+    chat = ChatGroq(
+        model="deepseek-r1-distill-llama-70b",
+        reasoning_format="parsed",
+    )
+    message = [
+        SystemMessage(
+            content="You are a helpful assistant that translates English to French."
+        ),
+        HumanMessage(content="I love programming."),
+    ]
+    response = chat.invoke(message)
+    assert isinstance(response, AIMessage)
+    assert "reasoning_content" in response.additional_kwargs
+    assert isinstance(response.additional_kwargs["reasoning_content"], str)
+    assert len(response.additional_kwargs["reasoning_content"]) > 0
+
+
+def test_reasoning_output_stream() -> None:
+    """Test reasoning output from ChatGroq with stream."""
+    chat = ChatGroq(
+        model="deepseek-r1-distill-llama-70b",
+        reasoning_format="parsed",
+    )
+    message = [
+        SystemMessage(
+            content="You are a helpful assistant that translates English to French."
+        ),
+        HumanMessage(content="I love programming."),
+    ]
+
+    full_response: Optional[AIMessageChunk] = None
+    for token in chat.stream(message):
+        assert isinstance(token, AIMessageChunk)
+
+        if full_response is None:
+            full_response = token
+        else:
+            # Casting since adding results in a type error
+            full_response = cast(AIMessageChunk, full_response + token)
+
+    assert full_response is not None
+    assert isinstance(full_response, AIMessageChunk)
+    assert "reasoning_content" in full_response.additional_kwargs
+    assert isinstance(full_response.additional_kwargs["reasoning_content"], str)
+    assert len(full_response.additional_kwargs["reasoning_content"]) > 0
+
+
 #
 # Misc tests
 #