From 816f898379bfe1ca4e91e551f1b7bdfd3faae04a Mon Sep 17 00:00:00 2001 From: MshariAlaeena Date: Tue, 28 Oct 2025 18:38:35 +0300 Subject: [PATCH 1/6] feat(groq): add prompt caching token usage details --- .../groq/langchain_groq/chat_models.py | 62 ++- .../groq/tests/unit_tests/test_chat_models.py | 358 +++++++++++++++++- 2 files changed, 403 insertions(+), 17 deletions(-) diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index bbbe94dad06c4..69234be8b4558 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -37,6 +37,11 @@ ToolMessage, ToolMessageChunk, ) +from langchain_core.messages.ai import ( + InputTokenDetails, + OutputTokenDetails, + UsageMetadata, +) from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( @@ -704,15 +709,7 @@ def _create_chat_result( for res in response["choices"]: message = _convert_dict_to_message(res["message"]) if token_usage and isinstance(message, AIMessage): - input_tokens = token_usage.get("prompt_tokens", 0) - output_tokens = token_usage.get("completion_tokens", 0) - message.usage_metadata = { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "total_tokens": token_usage.get( - "total_tokens", input_tokens + output_tokens - ), - } + message.usage_metadata = _create_usage_metadata(token_usage) generation_info = {"finish_reason": res.get("finish_reason")} if "logprobs" in res: generation_info["logprobs"] = res["logprobs"] @@ -1303,13 +1300,7 @@ def _convert_chunk_to_message_chunk( {k: executed_tool[k] for k in executed_tool if k != "output"} ) if usage := (chunk.get("x_groq") or {}).get("usage"): - input_tokens = usage.get("prompt_tokens", 0) - output_tokens = usage.get("completion_tokens", 0) - usage_metadata = { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "total_tokens": usage.get("total_tokens", input_tokens + output_tokens), - } + usage_metadata = _create_usage_metadata(usage) else: usage_metadata = None return AIMessageChunk( @@ -1409,3 +1400,42 @@ def _lc_invalid_tool_call_to_groq_tool_call( "arguments": invalid_tool_call["args"], }, } + +def _create_usage_metadata(groq_token_usage: dict) -> UsageMetadata: + """Create usage metadata from Groq token usage response. + + Args: + groq_token_usage: Token usage dict from Groq API response. + + Returns: + Usage metadata dict with input/output token details. 
+ """ + input_tokens = groq_token_usage.get("prompt_tokens", 0) + output_tokens = groq_token_usage.get("completion_tokens", 0) + total_tokens = groq_token_usage.get("total_tokens", input_tokens + output_tokens) + + usage_metadata: UsageMetadata = { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": total_tokens, + } + + if prompt_details := groq_token_usage.get("prompt_tokens_details"): + input_token_details: InputTokenDetails = {} + if cached := prompt_details.get("cached_tokens"): + input_token_details["cache_read"] = cached + if input_token_details: + usage_metadata["input_token_details"] = input_token_details + + if completion_details := groq_token_usage.get("completion_tokens_details"): + output_token_details: OutputTokenDetails = {} + if reasoning := completion_details.get("reasoning_tokens"): + output_token_details["reasoning"] = reasoning + if accepted := completion_details.get("accepted_prediction_tokens"): + output_token_details["accepted_prediction"] = accepted + if rejected := completion_details.get("rejected_prediction_tokens"): + output_token_details["rejected_prediction"] = rejected + if output_token_details: + usage_metadata["output_token_details"] = output_token_details + + return usage_metadata diff --git a/libs/partners/groq/tests/unit_tests/test_chat_models.py b/libs/partners/groq/tests/unit_tests/test_chat_models.py index bad3474cef9a3..3268835ce600f 100644 --- a/libs/partners/groq/tests/unit_tests/test_chat_models.py +++ b/libs/partners/groq/tests/unit_tests/test_chat_models.py @@ -9,6 +9,7 @@ import pytest from langchain_core.messages import ( AIMessage, + AIMessageChunk, FunctionMessage, HumanMessage, InvalidToolCall, @@ -16,7 +17,12 @@ ToolCall, ) -from langchain_groq.chat_models import ChatGroq, _convert_dict_to_message +from langchain_groq.chat_models import ( + ChatGroq, + _convert_chunk_to_message_chunk, + _convert_dict_to_message, + _create_usage_metadata, +) if "GROQ_API_KEY" not in os.environ: os.environ["GROQ_API_KEY"] = "fake-key" @@ -283,3 +289,353 @@ def test_groq_serialization() -> None: # Ensure a None was preserved assert llm.groq_api_base == llm2.groq_api_base + + +def test_create_usage_metadata_basic() -> None: + """Test basic usage metadata creation without details.""" + token_usage = { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + } + + result = _create_usage_metadata(token_usage) + + assert isinstance(result, dict) + assert result["input_tokens"] == 100 + assert result["output_tokens"] == 50 + assert result["total_tokens"] == 150 + assert "input_token_details" not in result + assert "output_token_details" not in result + + +def test_create_usage_metadata_with_cached_tokens() -> None: + """Test usage metadata with prompt caching.""" + token_usage = { + "prompt_tokens": 2006, + "completion_tokens": 300, + "total_tokens": 2306, + "prompt_tokens_details": { + "cached_tokens": 1920 + } + } + + result = _create_usage_metadata(token_usage) + + assert isinstance(result, dict) + assert result["input_tokens"] == 2006 + assert result["output_tokens"] == 300 + assert result["total_tokens"] == 2306 + assert "input_token_details" in result + assert isinstance(result["input_token_details"], dict) + assert result["input_token_details"]["cache_read"] == 1920 + assert "output_token_details" not in result + + +def test_create_usage_metadata_with_reasoning_tokens() -> None: + """Test usage metadata with reasoning tokens.""" + token_usage = { + "prompt_tokens": 100, + "completion_tokens": 150, + 
"total_tokens": 250, + "completion_tokens_details": { + "reasoning_tokens": 50 + } + } + + result = _create_usage_metadata(token_usage) + + assert isinstance(result, dict) + assert result["input_tokens"] == 100 + assert result["output_tokens"] == 150 + assert result["total_tokens"] == 250 + assert "input_token_details" not in result + assert "output_token_details" in result + assert isinstance(result["output_token_details"], dict) + assert result["output_token_details"]["reasoning"] == 50 + + +def test_create_usage_metadata_with_prediction_tokens() -> None: + """Test usage metadata with prediction tokens.""" + token_usage = { + "prompt_tokens": 100, + "completion_tokens": 200, + "total_tokens": 300, + "completion_tokens_details": { + "reasoning_tokens": 0, + "accepted_prediction_tokens": 50, + "rejected_prediction_tokens": 10 + } + } + + result = _create_usage_metadata(token_usage) + + assert isinstance(result, dict) + assert result["input_tokens"] == 100 + assert result["output_tokens"] == 200 + assert result["total_tokens"] == 300 + assert "input_token_details" not in result + assert "output_token_details" in result + assert isinstance(result["output_token_details"], dict) + assert "reasoning" not in result["output_token_details"] + assert result["output_token_details"]["accepted_prediction"] == 50 + assert result["output_token_details"]["rejected_prediction"] == 10 + + +def test_create_usage_metadata_with_all_details() -> None: + """Test usage metadata with all token details.""" + token_usage = { + "prompt_tokens": 2006, + "completion_tokens": 300, + "total_tokens": 2306, + "prompt_tokens_details": { + "cached_tokens": 1920 + }, + "completion_tokens_details": { + "reasoning_tokens": 50, + "accepted_prediction_tokens": 20, + "rejected_prediction_tokens": 5 + } + } + + result = _create_usage_metadata(token_usage) + + assert isinstance(result, dict) + assert result["input_tokens"] == 2006 + assert result["output_tokens"] == 300 + assert result["total_tokens"] == 2306 + + assert "input_token_details" in result + assert isinstance(result["input_token_details"], dict) + assert result["input_token_details"]["cache_read"] == 1920 + + assert "output_token_details" in result + assert isinstance(result["output_token_details"], dict) + assert result["output_token_details"]["reasoning"] == 50 + assert result["output_token_details"]["accepted_prediction"] == 20 + assert result["output_token_details"]["rejected_prediction"] == 5 + + +def test_create_usage_metadata_missing_total_tokens() -> None: + """Test that total_tokens is calculated when missing.""" + token_usage = { + "prompt_tokens": 100, + "completion_tokens": 50, + } + + result = _create_usage_metadata(token_usage) + + assert result["input_tokens"] == 100 + assert result["output_tokens"] == 50 + assert result["total_tokens"] == 150 + + +def test_create_usage_metadata_empty_details() -> None: + """Test that empty detail dicts don't create token detail objects.""" + token_usage = { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "prompt_tokens_details": {}, + "completion_tokens_details": {} + } + + result = _create_usage_metadata(token_usage) + + assert result["input_tokens"] == 100 + assert result["output_tokens"] == 50 + assert result["total_tokens"] == 150 + assert "input_token_details" not in result + assert "output_token_details" not in result + + +def test_create_usage_metadata_zero_cached_tokens() -> None: + """Test that zero cached tokens are not included (falsy).""" + token_usage = { + "prompt_tokens": 100, + 
"completion_tokens": 50, + "total_tokens": 150, + "prompt_tokens_details": { + "cached_tokens": 0 + } + } + + result = _create_usage_metadata(token_usage) + + assert result["input_tokens"] == 100 + assert result["output_tokens"] == 50 + assert result["total_tokens"] == 150 + assert "input_token_details" not in result + + +def test_chat_result_with_usage_metadata() -> None: + """Test that _create_chat_result properly includes usage metadata.""" + llm = ChatGroq(model="test-model") + + mock_response = { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Test response", + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 2006, + "completion_tokens": 300, + "total_tokens": 2306, + "prompt_tokens_details": { + "cached_tokens": 1920 + }, + "completion_tokens_details": { + "reasoning_tokens": 50 + } + } + } + + result = llm._create_chat_result(mock_response, {}) + + assert len(result.generations) == 1 + message = result.generations[0].message + assert isinstance(message, AIMessage) + assert message.content == "Test response" + + assert message.usage_metadata is not None + assert isinstance(message.usage_metadata, dict) + assert message.usage_metadata["input_tokens"] == 2006 + assert message.usage_metadata["output_tokens"] == 300 + assert message.usage_metadata["total_tokens"] == 2306 + + assert "input_token_details" in message.usage_metadata + assert message.usage_metadata["input_token_details"]["cache_read"] == 1920 + + assert "output_token_details" in message.usage_metadata + assert message.usage_metadata["output_token_details"]["reasoning"] == 50 + + +def test_chat_result_backward_compatibility() -> None: + """Test that responses without new fields still work.""" + llm = ChatGroq(model="test-model") + + mock_response = { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Test response", + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + } + } + + result = llm._create_chat_result(mock_response, {}) + + assert len(result.generations) == 1 + message = result.generations[0].message + assert isinstance(message, AIMessage) + + assert message.usage_metadata is not None + assert message.usage_metadata["input_tokens"] == 100 + assert message.usage_metadata["output_tokens"] == 50 + assert message.usage_metadata["total_tokens"] == 150 + + assert "input_token_details" not in message.usage_metadata + assert "output_token_details" not in message.usage_metadata + + +def test_streaming_with_usage_metadata() -> None: + """Test that streaming properly includes usage metadata.""" + chunk = { + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "content": "Hello", + }, + "finish_reason": None, + } + ], + "x_groq": { + "usage": { + "prompt_tokens": 2006, + "completion_tokens": 300, + "total_tokens": 2306, + "prompt_tokens_details": { + "cached_tokens": 1920 + }, + "completion_tokens_details": { + "reasoning_tokens": 50, + "accepted_prediction_tokens": 20, + "rejected_prediction_tokens": 5 + } + } + } + } + + result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk) + + assert isinstance(result, AIMessageChunk) 
+ assert result.content == "Hello" + + assert result.usage_metadata is not None + assert isinstance(result.usage_metadata, dict) + assert result.usage_metadata["input_tokens"] == 2006 + assert result.usage_metadata["output_tokens"] == 300 + assert result.usage_metadata["total_tokens"] == 2306 + + assert "input_token_details" in result.usage_metadata + assert result.usage_metadata["input_token_details"]["cache_read"] == 1920 + + assert "output_token_details" in result.usage_metadata + assert result.usage_metadata["output_token_details"]["reasoning"] == 50 + assert result.usage_metadata["output_token_details"]["accepted_prediction"] == 20 + assert result.usage_metadata["output_token_details"]["rejected_prediction"] == 5 + + +def test_streaming_without_usage_metadata() -> None: + """Test that streaming works without usage metadata (backward compatibility).""" + chunk = { + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "content": "Hello", + }, + "finish_reason": None, + } + ], + } + + result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk) + + assert isinstance(result, AIMessageChunk) + assert result.content == "Hello" + assert result.usage_metadata is None From a89e2e2d1161dbceccbb92dda6abd8a850db254b Mon Sep 17 00:00:00 2001 From: MshariAlaeena Date: Tue, 28 Oct 2025 21:42:12 +0300 Subject: [PATCH 2/6] chore: fix linting issue --- .../groq/langchain_groq/chat_models.py | 40 ++++++------ .../groq/tests/unit_tests/test_chat_models.py | 64 +++++++------------ 2 files changed, 41 insertions(+), 63 deletions(-) diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index 69234be8b4558..59ae62ddc575a 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -1401,6 +1401,7 @@ def _lc_invalid_tool_call_to_groq_tool_call( }, } + def _create_usage_metadata(groq_token_usage: dict) -> UsageMetadata: """Create usage metadata from Groq token usage response. @@ -1410,32 +1411,27 @@ def _create_usage_metadata(groq_token_usage: dict) -> UsageMetadata: Returns: Usage metadata dict with input/output token details. 
""" - input_tokens = groq_token_usage.get("prompt_tokens", 0) - output_tokens = groq_token_usage.get("completion_tokens", 0) - total_tokens = groq_token_usage.get("total_tokens", input_tokens + output_tokens) - + input_tokens = groq_token_usage.get("prompt_tokens") or 0 + output_tokens = groq_token_usage.get("completion_tokens") or 0 + total_tokens = groq_token_usage.get("total_tokens") or input_tokens + output_tokens + input_token_details: dict = { + "cache_read": (groq_token_usage.get("prompt_tokens_details") or {}).get( + "cached_tokens" + ), + } + output_token_details: dict = { + "reasoning": (groq_token_usage.get("completion_tokens_details") or {}).get( + "reasoning_tokens" + ), + } usage_metadata: UsageMetadata = { "input_tokens": input_tokens, "output_tokens": output_tokens, "total_tokens": total_tokens, } - if prompt_details := groq_token_usage.get("prompt_tokens_details"): - input_token_details: InputTokenDetails = {} - if cached := prompt_details.get("cached_tokens"): - input_token_details["cache_read"] = cached - if input_token_details: - usage_metadata["input_token_details"] = input_token_details - - if completion_details := groq_token_usage.get("completion_tokens_details"): - output_token_details: OutputTokenDetails = {} - if reasoning := completion_details.get("reasoning_tokens"): - output_token_details["reasoning"] = reasoning - if accepted := completion_details.get("accepted_prediction_tokens"): - output_token_details["accepted_prediction"] = accepted - if rejected := completion_details.get("rejected_prediction_tokens"): - output_token_details["rejected_prediction"] = rejected - if output_token_details: - usage_metadata["output_token_details"] = output_token_details - + if filtered_input := {k: v for k, v in input_token_details.items() if v}: + usage_metadata["input_token_details"] = InputTokenDetails(**filtered_input) # type: ignore[typeddict-item] + if filtered_output := {k: v for k, v in output_token_details.items() if v}: + usage_metadata["output_token_details"] = OutputTokenDetails(**filtered_output) # type: ignore[typeddict-item] return usage_metadata diff --git a/libs/partners/groq/tests/unit_tests/test_chat_models.py b/libs/partners/groq/tests/unit_tests/test_chat_models.py index 3268835ce600f..8a6c43d32f4da 100644 --- a/libs/partners/groq/tests/unit_tests/test_chat_models.py +++ b/libs/partners/groq/tests/unit_tests/test_chat_models.py @@ -315,9 +315,7 @@ def test_create_usage_metadata_with_cached_tokens() -> None: "prompt_tokens": 2006, "completion_tokens": 300, "total_tokens": 2306, - "prompt_tokens_details": { - "cached_tokens": 1920 - } + "prompt_tokens_details": {"cached_tokens": 1920}, } result = _create_usage_metadata(token_usage) @@ -338,9 +336,7 @@ def test_create_usage_metadata_with_reasoning_tokens() -> None: "prompt_tokens": 100, "completion_tokens": 150, "total_tokens": 250, - "completion_tokens_details": { - "reasoning_tokens": 50 - } + "completion_tokens_details": {"reasoning_tokens": 50}, } result = _create_usage_metadata(token_usage) @@ -356,7 +352,11 @@ def test_create_usage_metadata_with_reasoning_tokens() -> None: def test_create_usage_metadata_with_prediction_tokens() -> None: - """Test usage metadata with prediction tokens.""" + """Test usage metadata with prediction tokens. + + Note: Prediction tokens are not included in usage_metadata. + Users can access them from response_metadata["token_usage"] if needed. 
+ """ token_usage = { "prompt_tokens": 100, "completion_tokens": 200, @@ -364,8 +364,8 @@ def test_create_usage_metadata_with_prediction_tokens() -> None: "completion_tokens_details": { "reasoning_tokens": 0, "accepted_prediction_tokens": 50, - "rejected_prediction_tokens": 10 - } + "rejected_prediction_tokens": 10, + }, } result = _create_usage_metadata(token_usage) @@ -375,11 +375,7 @@ def test_create_usage_metadata_with_prediction_tokens() -> None: assert result["output_tokens"] == 200 assert result["total_tokens"] == 300 assert "input_token_details" not in result - assert "output_token_details" in result - assert isinstance(result["output_token_details"], dict) - assert "reasoning" not in result["output_token_details"] - assert result["output_token_details"]["accepted_prediction"] == 50 - assert result["output_token_details"]["rejected_prediction"] == 10 + assert "output_token_details" not in result def test_create_usage_metadata_with_all_details() -> None: @@ -388,14 +384,12 @@ def test_create_usage_metadata_with_all_details() -> None: "prompt_tokens": 2006, "completion_tokens": 300, "total_tokens": 2306, - "prompt_tokens_details": { - "cached_tokens": 1920 - }, + "prompt_tokens_details": {"cached_tokens": 1920}, "completion_tokens_details": { "reasoning_tokens": 50, "accepted_prediction_tokens": 20, - "rejected_prediction_tokens": 5 - } + "rejected_prediction_tokens": 5, + }, } result = _create_usage_metadata(token_usage) @@ -412,8 +406,6 @@ def test_create_usage_metadata_with_all_details() -> None: assert "output_token_details" in result assert isinstance(result["output_token_details"], dict) assert result["output_token_details"]["reasoning"] == 50 - assert result["output_token_details"]["accepted_prediction"] == 20 - assert result["output_token_details"]["rejected_prediction"] == 5 def test_create_usage_metadata_missing_total_tokens() -> None: @@ -437,7 +429,7 @@ def test_create_usage_metadata_empty_details() -> None: "completion_tokens": 50, "total_tokens": 150, "prompt_tokens_details": {}, - "completion_tokens_details": {} + "completion_tokens_details": {}, } result = _create_usage_metadata(token_usage) @@ -455,9 +447,7 @@ def test_create_usage_metadata_zero_cached_tokens() -> None: "prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150, - "prompt_tokens_details": { - "cached_tokens": 0 - } + "prompt_tokens_details": {"cached_tokens": 0}, } result = _create_usage_metadata(token_usage) @@ -491,13 +481,9 @@ def test_chat_result_with_usage_metadata() -> None: "prompt_tokens": 2006, "completion_tokens": 300, "total_tokens": 2306, - "prompt_tokens_details": { - "cached_tokens": 1920 - }, - "completion_tokens_details": { - "reasoning_tokens": 50 - } - } + "prompt_tokens_details": {"cached_tokens": 1920}, + "completion_tokens_details": {"reasoning_tokens": 50}, + }, } result = llm._create_chat_result(mock_response, {}) @@ -543,7 +529,7 @@ def test_chat_result_backward_compatibility() -> None: "prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150, - } + }, } result = llm._create_chat_result(mock_response, {}) @@ -583,16 +569,14 @@ def test_streaming_with_usage_metadata() -> None: "prompt_tokens": 2006, "completion_tokens": 300, "total_tokens": 2306, - "prompt_tokens_details": { - "cached_tokens": 1920 - }, + "prompt_tokens_details": {"cached_tokens": 1920}, "completion_tokens_details": { "reasoning_tokens": 50, "accepted_prediction_tokens": 20, - "rejected_prediction_tokens": 5 - } + "rejected_prediction_tokens": 5, + }, } - } + }, } result = 
_convert_chunk_to_message_chunk(chunk, AIMessageChunk) @@ -611,8 +595,6 @@ def test_streaming_with_usage_metadata() -> None: assert "output_token_details" in result.usage_metadata assert result.usage_metadata["output_token_details"]["reasoning"] == 50 - assert result.usage_metadata["output_token_details"]["accepted_prediction"] == 20 - assert result.usage_metadata["output_token_details"]["rejected_prediction"] == 5 def test_streaming_without_usage_metadata() -> None: From 2aa122be1dea69a5071978d87688d421dfc1cb00 Mon Sep 17 00:00:00 2001 From: MshariAlaeena Date: Tue, 28 Oct 2025 22:07:56 +0300 Subject: [PATCH 3/6] chore: fix test cases --- .../groq/langchain_groq/chat_models.py | 8 --- .../groq/tests/unit_tests/test_chat_models.py | 68 ++----------------- 2 files changed, 4 insertions(+), 72 deletions(-) diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index 59ae62ddc575a..74265aef0e8ca 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -39,7 +39,6 @@ ) from langchain_core.messages.ai import ( InputTokenDetails, - OutputTokenDetails, UsageMetadata, ) from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser @@ -1419,11 +1418,6 @@ def _create_usage_metadata(groq_token_usage: dict) -> UsageMetadata: "cached_tokens" ), } - output_token_details: dict = { - "reasoning": (groq_token_usage.get("completion_tokens_details") or {}).get( - "reasoning_tokens" - ), - } usage_metadata: UsageMetadata = { "input_tokens": input_tokens, "output_tokens": output_tokens, @@ -1432,6 +1426,4 @@ def _create_usage_metadata(groq_token_usage: dict) -> UsageMetadata: if filtered_input := {k: v for k, v in input_token_details.items() if v}: usage_metadata["input_token_details"] = InputTokenDetails(**filtered_input) # type: ignore[typeddict-item] - if filtered_output := {k: v for k, v in output_token_details.items() if v}: - usage_metadata["output_token_details"] = OutputTokenDetails(**filtered_output) # type: ignore[typeddict-item] return usage_metadata diff --git a/libs/partners/groq/tests/unit_tests/test_chat_models.py b/libs/partners/groq/tests/unit_tests/test_chat_models.py index 8a6c43d32f4da..58df10ccf82a5 100644 --- a/libs/partners/groq/tests/unit_tests/test_chat_models.py +++ b/libs/partners/groq/tests/unit_tests/test_chat_models.py @@ -330,66 +330,17 @@ def test_create_usage_metadata_with_cached_tokens() -> None: assert "output_token_details" not in result -def test_create_usage_metadata_with_reasoning_tokens() -> None: - """Test usage metadata with reasoning tokens.""" - token_usage = { - "prompt_tokens": 100, - "completion_tokens": 150, - "total_tokens": 250, - "completion_tokens_details": {"reasoning_tokens": 50}, - } - result = _create_usage_metadata(token_usage) - assert isinstance(result, dict) - assert result["input_tokens"] == 100 - assert result["output_tokens"] == 150 - assert result["total_tokens"] == 250 - assert "input_token_details" not in result - assert "output_token_details" in result - assert isinstance(result["output_token_details"], dict) - assert result["output_token_details"]["reasoning"] == 50 - - -def test_create_usage_metadata_with_prediction_tokens() -> None: - """Test usage metadata with prediction tokens. - - Note: Prediction tokens are not included in usage_metadata. - Users can access them from response_metadata["token_usage"] if needed. 
- """ - token_usage = { - "prompt_tokens": 100, - "completion_tokens": 200, - "total_tokens": 300, - "completion_tokens_details": { - "reasoning_tokens": 0, - "accepted_prediction_tokens": 50, - "rejected_prediction_tokens": 10, - }, - } - - result = _create_usage_metadata(token_usage) - - assert isinstance(result, dict) - assert result["input_tokens"] == 100 - assert result["output_tokens"] == 200 - assert result["total_tokens"] == 300 - assert "input_token_details" not in result - assert "output_token_details" not in result def test_create_usage_metadata_with_all_details() -> None: - """Test usage metadata with all token details.""" + """Test usage metadata with all available details.""" token_usage = { "prompt_tokens": 2006, "completion_tokens": 300, "total_tokens": 2306, "prompt_tokens_details": {"cached_tokens": 1920}, - "completion_tokens_details": { - "reasoning_tokens": 50, - "accepted_prediction_tokens": 20, - "rejected_prediction_tokens": 5, - }, } result = _create_usage_metadata(token_usage) @@ -403,9 +354,7 @@ def test_create_usage_metadata_with_all_details() -> None: assert isinstance(result["input_token_details"], dict) assert result["input_token_details"]["cache_read"] == 1920 - assert "output_token_details" in result - assert isinstance(result["output_token_details"], dict) - assert result["output_token_details"]["reasoning"] == 50 + assert "output_token_details" not in result def test_create_usage_metadata_missing_total_tokens() -> None: @@ -429,7 +378,6 @@ def test_create_usage_metadata_empty_details() -> None: "completion_tokens": 50, "total_tokens": 150, "prompt_tokens_details": {}, - "completion_tokens_details": {}, } result = _create_usage_metadata(token_usage) @@ -482,7 +430,6 @@ def test_chat_result_with_usage_metadata() -> None: "completion_tokens": 300, "total_tokens": 2306, "prompt_tokens_details": {"cached_tokens": 1920}, - "completion_tokens_details": {"reasoning_tokens": 50}, }, } @@ -502,8 +449,7 @@ def test_chat_result_with_usage_metadata() -> None: assert "input_token_details" in message.usage_metadata assert message.usage_metadata["input_token_details"]["cache_read"] == 1920 - assert "output_token_details" in message.usage_metadata - assert message.usage_metadata["output_token_details"]["reasoning"] == 50 + assert "output_token_details" not in message.usage_metadata def test_chat_result_backward_compatibility() -> None: @@ -570,11 +516,6 @@ def test_streaming_with_usage_metadata() -> None: "completion_tokens": 300, "total_tokens": 2306, "prompt_tokens_details": {"cached_tokens": 1920}, - "completion_tokens_details": { - "reasoning_tokens": 50, - "accepted_prediction_tokens": 20, - "rejected_prediction_tokens": 5, - }, } }, } @@ -593,8 +534,7 @@ def test_streaming_with_usage_metadata() -> None: assert "input_token_details" in result.usage_metadata assert result.usage_metadata["input_token_details"]["cache_read"] == 1920 - assert "output_token_details" in result.usage_metadata - assert result.usage_metadata["output_token_details"]["reasoning"] == 50 + assert "output_token_details" not in result.usage_metadata def test_streaming_without_usage_metadata() -> None: From 8f189ea45ce8282b586e3ba01aca8563e1580fd4 Mon Sep 17 00:00:00 2001 From: MshariAlaeena Date: Tue, 28 Oct 2025 22:17:52 +0300 Subject: [PATCH 4/6] chore: format code --- libs/partners/groq/tests/unit_tests/test_chat_models.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/libs/partners/groq/tests/unit_tests/test_chat_models.py b/libs/partners/groq/tests/unit_tests/test_chat_models.py index 
58df10ccf82a5..458f853f4b262 100644 --- a/libs/partners/groq/tests/unit_tests/test_chat_models.py +++ b/libs/partners/groq/tests/unit_tests/test_chat_models.py @@ -330,10 +330,6 @@ def test_create_usage_metadata_with_cached_tokens() -> None: assert "output_token_details" not in result - - - - def test_create_usage_metadata_with_all_details() -> None: """Test usage metadata with all available details.""" token_usage = { From 469e62d583bd82ceb245b4a80ba9ff4373b1d395 Mon Sep 17 00:00:00 2001 From: MshariAlaeena Date: Tue, 4 Nov 2025 12:13:06 +0300 Subject: [PATCH 5/6] feat(groq): add reasoning tokens support with tests --- .../groq/langchain_groq/chat_models.py | 23 +- .../groq/tests/unit_tests/test_chat_models.py | 338 ++++++++++++++++++ 2 files changed, 360 insertions(+), 1 deletion(-) diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index 74265aef0e8ca..a1cce0431cc60 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -39,6 +39,7 @@ ) from langchain_core.messages.ai import ( InputTokenDetails, + OutputTokenDetails, UsageMetadata, ) from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser @@ -748,7 +749,20 @@ def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict: if token_usage is not None: for k, v in token_usage.items(): if k in overall_token_usage and v is not None: - overall_token_usage[k] += v + # Handle nested dictionaries + if isinstance(v, dict): + if k not in overall_token_usage: + overall_token_usage[k] = {} + for nested_k, nested_v in v.items(): + if ( + nested_k in overall_token_usage[k] + and nested_v is not None + ): + overall_token_usage[k][nested_k] += nested_v + else: + overall_token_usage[k][nested_k] = nested_v + else: + overall_token_usage[k] += v else: overall_token_usage[k] = v if system_fingerprint is None: @@ -1418,6 +1432,11 @@ def _create_usage_metadata(groq_token_usage: dict) -> UsageMetadata: "cached_tokens" ), } + output_token_details: dict = { + "reasoning": (groq_token_usage.get("completion_tokens_details") or {}).get( + "reasoning_tokens" + ), + } usage_metadata: UsageMetadata = { "input_tokens": input_tokens, "output_tokens": output_tokens, @@ -1426,4 +1445,6 @@ def _create_usage_metadata(groq_token_usage: dict) -> UsageMetadata: if filtered_input := {k: v for k, v in input_token_details.items() if v}: usage_metadata["input_token_details"] = InputTokenDetails(**filtered_input) # type: ignore[typeddict-item] + if filtered_output := {k: v for k, v in output_token_details.items() if v}: + usage_metadata["output_token_details"] = OutputTokenDetails(**filtered_output) # type: ignore[typeddict-item] return usage_metadata diff --git a/libs/partners/groq/tests/unit_tests/test_chat_models.py b/libs/partners/groq/tests/unit_tests/test_chat_models.py index 458f853f4b262..607b96fe65201 100644 --- a/libs/partners/groq/tests/unit_tests/test_chat_models.py +++ b/libs/partners/groq/tests/unit_tests/test_chat_models.py @@ -402,6 +402,87 @@ def test_create_usage_metadata_zero_cached_tokens() -> None: assert "input_token_details" not in result +def test_create_usage_metadata_with_reasoning_tokens() -> None: + """Test usage metadata with reasoning tokens.""" + token_usage = { + "prompt_tokens": 100, + "completion_tokens": 450, + "total_tokens": 550, + "completion_tokens_details": {"reasoning_tokens": 200}, + } + + result = _create_usage_metadata(token_usage) + + assert isinstance(result, dict) + assert 
result["input_tokens"] == 100 + assert result["output_tokens"] == 450 + assert result["total_tokens"] == 550 + assert "output_token_details" in result + assert isinstance(result["output_token_details"], dict) + assert result["output_token_details"]["reasoning"] == 200 + assert "input_token_details" not in result + + +def test_create_usage_metadata_with_cached_and_reasoning_tokens() -> None: + """Test usage metadata with both cached and reasoning tokens.""" + token_usage = { + "prompt_tokens": 2006, + "completion_tokens": 450, + "total_tokens": 2456, + "prompt_tokens_details": {"cached_tokens": 1920}, + "completion_tokens_details": {"reasoning_tokens": 200}, + } + + result = _create_usage_metadata(token_usage) + + assert isinstance(result, dict) + assert result["input_tokens"] == 2006 + assert result["output_tokens"] == 450 + assert result["total_tokens"] == 2456 + + assert "input_token_details" in result + assert isinstance(result["input_token_details"], dict) + assert result["input_token_details"]["cache_read"] == 1920 + + assert "output_token_details" in result + assert isinstance(result["output_token_details"], dict) + assert result["output_token_details"]["reasoning"] == 200 + + +def test_create_usage_metadata_zero_reasoning_tokens() -> None: + """Test that zero reasoning tokens are not included (falsy).""" + token_usage = { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "completion_tokens_details": {"reasoning_tokens": 0}, + } + + result = _create_usage_metadata(token_usage) + + assert result["input_tokens"] == 100 + assert result["output_tokens"] == 50 + assert result["total_tokens"] == 150 + assert "output_token_details" not in result + + +def test_create_usage_metadata_empty_completion_details() -> None: + """Test that empty completion_tokens_details don't create output_token_details.""" + token_usage = { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "completion_tokens_details": {}, + } + + result = _create_usage_metadata(token_usage) + + assert result["input_tokens"] == 100 + assert result["output_tokens"] == 50 + assert result["total_tokens"] == 150 + assert "output_token_details" not in result + + def test_chat_result_with_usage_metadata() -> None: """Test that _create_chat_result properly includes usage metadata.""" llm = ChatGroq(model="test-model") @@ -448,6 +529,100 @@ def test_chat_result_with_usage_metadata() -> None: assert "output_token_details" not in message.usage_metadata +def test_chat_result_with_reasoning_tokens() -> None: + """Test that _create_chat_result properly includes reasoning tokens.""" + llm = ChatGroq(model="test-model") + + mock_response = { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Test reasoning response", + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 450, + "total_tokens": 550, + "completion_tokens_details": {"reasoning_tokens": 200}, + }, + } + + result = llm._create_chat_result(mock_response, {}) + + assert len(result.generations) == 1 + message = result.generations[0].message + assert isinstance(message, AIMessage) + assert message.content == "Test reasoning response" + + assert message.usage_metadata is not None + assert isinstance(message.usage_metadata, dict) + assert message.usage_metadata["input_tokens"] == 100 + assert message.usage_metadata["output_tokens"] == 450 + assert 
message.usage_metadata["total_tokens"] == 550 + + assert "output_token_details" in message.usage_metadata + assert message.usage_metadata["output_token_details"]["reasoning"] == 200 + + assert "input_token_details" not in message.usage_metadata + + +def test_chat_result_with_cached_and_reasoning_tokens() -> None: + """Test that _create_chat_result includes both cached and reasoning tokens.""" + llm = ChatGroq(model="test-model") + + mock_response = { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Test response with both", + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 2006, + "completion_tokens": 450, + "total_tokens": 2456, + "prompt_tokens_details": {"cached_tokens": 1920}, + "completion_tokens_details": {"reasoning_tokens": 200}, + }, + } + + result = llm._create_chat_result(mock_response, {}) + + assert len(result.generations) == 1 + message = result.generations[0].message + assert isinstance(message, AIMessage) + assert message.content == "Test response with both" + + assert message.usage_metadata is not None + assert isinstance(message.usage_metadata, dict) + assert message.usage_metadata["input_tokens"] == 2006 + assert message.usage_metadata["output_tokens"] == 450 + assert message.usage_metadata["total_tokens"] == 2456 + + assert "input_token_details" in message.usage_metadata + assert message.usage_metadata["input_token_details"]["cache_read"] == 1920 + + assert "output_token_details" in message.usage_metadata + assert message.usage_metadata["output_token_details"]["reasoning"] == 200 + + def test_chat_result_backward_compatibility() -> None: """Test that responses without new fields still work.""" llm = ChatGroq(model="test-model") @@ -533,6 +708,96 @@ def test_streaming_with_usage_metadata() -> None: assert "output_token_details" not in result.usage_metadata +def test_streaming_with_reasoning_tokens() -> None: + """Test that streaming properly includes reasoning tokens in usage metadata.""" + chunk = { + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "content": "Hello", + }, + "finish_reason": None, + } + ], + "x_groq": { + "usage": { + "prompt_tokens": 100, + "completion_tokens": 450, + "total_tokens": 550, + "completion_tokens_details": {"reasoning_tokens": 200}, + } + }, + } + + result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk) + + assert isinstance(result, AIMessageChunk) + assert result.content == "Hello" + + assert result.usage_metadata is not None + assert isinstance(result.usage_metadata, dict) + assert result.usage_metadata["input_tokens"] == 100 + assert result.usage_metadata["output_tokens"] == 450 + assert result.usage_metadata["total_tokens"] == 550 + + assert "output_token_details" in result.usage_metadata + assert result.usage_metadata["output_token_details"]["reasoning"] == 200 + + assert "input_token_details" not in result.usage_metadata + + +def test_streaming_with_cached_and_reasoning_tokens() -> None: + """Test that streaming includes both cached and reasoning tokens.""" + chunk = { + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "content": "Hello", + }, + "finish_reason": None, + } + ], + "x_groq": { + "usage": { + 
"prompt_tokens": 2006, + "completion_tokens": 450, + "total_tokens": 2456, + "prompt_tokens_details": {"cached_tokens": 1920}, + "completion_tokens_details": {"reasoning_tokens": 200}, + } + }, + } + + result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk) + + assert isinstance(result, AIMessageChunk) + assert result.content == "Hello" + + assert result.usage_metadata is not None + assert isinstance(result.usage_metadata, dict) + assert result.usage_metadata["input_tokens"] == 2006 + assert result.usage_metadata["output_tokens"] == 450 + assert result.usage_metadata["total_tokens"] == 2456 + + assert "input_token_details" in result.usage_metadata + assert result.usage_metadata["input_token_details"]["cache_read"] == 1920 + + assert "output_token_details" in result.usage_metadata + assert result.usage_metadata["output_token_details"]["reasoning"] == 200 + + def test_streaming_without_usage_metadata() -> None: """Test that streaming works without usage metadata (backward compatibility).""" chunk = { @@ -557,3 +822,76 @@ def test_streaming_without_usage_metadata() -> None: assert isinstance(result, AIMessageChunk) assert result.content == "Hello" assert result.usage_metadata is None + + +def test_combine_llm_outputs_with_token_details() -> None: + """Test that _combine_llm_outputs properly combines nested token details.""" + llm = ChatGroq(model="test-model") + + llm_outputs: list[dict[str, Any] | None] = [ + { + "token_usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "prompt_tokens_details": {"cached_tokens": 80}, + "completion_tokens_details": {"reasoning_tokens": 20}, + }, + "model_name": "test-model", + "system_fingerprint": "fp_123", + }, + { + "token_usage": { + "prompt_tokens": 200, + "completion_tokens": 100, + "total_tokens": 300, + "prompt_tokens_details": {"cached_tokens": 150}, + "completion_tokens_details": {"reasoning_tokens": 40}, + }, + "model_name": "test-model", + "system_fingerprint": "fp_123", + }, + ] + + result = llm._combine_llm_outputs(llm_outputs) + + assert result["token_usage"]["prompt_tokens"] == 300 + assert result["token_usage"]["completion_tokens"] == 150 + assert result["token_usage"]["total_tokens"] == 450 + assert result["token_usage"]["prompt_tokens_details"]["cached_tokens"] == 230 + assert result["token_usage"]["completion_tokens_details"]["reasoning_tokens"] == 60 + assert result["model_name"] == "test-model" + assert result["system_fingerprint"] == "fp_123" + + +def test_combine_llm_outputs_with_missing_details() -> None: + """Test _combine_llm_outputs when some outputs have details and others don't.""" + llm = ChatGroq(model="test-model") + + llm_outputs: list[dict[str, Any] | None] = [ + { + "token_usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + }, + "model_name": "test-model", + }, + { + "token_usage": { + "prompt_tokens": 200, + "completion_tokens": 100, + "total_tokens": 300, + "completion_tokens_details": {"reasoning_tokens": 40}, + }, + "model_name": "test-model", + }, + ] + + result = llm._combine_llm_outputs(llm_outputs) + + assert result["token_usage"]["prompt_tokens"] == 300 + assert result["token_usage"]["completion_tokens"] == 150 + assert result["token_usage"]["total_tokens"] == 450 + assert result["token_usage"]["completion_tokens_details"]["reasoning_tokens"] == 40 + assert "prompt_tokens_details" not in result["token_usage"] From 5ba24f937c0d5a790b4ae75e46ea616a0b516eac Mon Sep 17 00:00:00 2001 From: MshariAlaeena Date: Thu, 6 Nov 2025 11:32:42 +0300 
Subject: [PATCH 6/6] fix: correct duplicate test to cover all token details --- .../groq/tests/unit_tests/test_chat_models.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/libs/partners/groq/tests/unit_tests/test_chat_models.py b/libs/partners/groq/tests/unit_tests/test_chat_models.py index 607b96fe65201..c69041cc1d8d1 100644 --- a/libs/partners/groq/tests/unit_tests/test_chat_models.py +++ b/libs/partners/groq/tests/unit_tests/test_chat_models.py @@ -334,23 +334,26 @@ def test_create_usage_metadata_with_all_details() -> None: """Test usage metadata with all available details.""" token_usage = { "prompt_tokens": 2006, - "completion_tokens": 300, - "total_tokens": 2306, + "completion_tokens": 450, + "total_tokens": 2456, "prompt_tokens_details": {"cached_tokens": 1920}, + "completion_tokens_details": {"reasoning_tokens": 200}, } result = _create_usage_metadata(token_usage) assert isinstance(result, dict) assert result["input_tokens"] == 2006 - assert result["output_tokens"] == 300 - assert result["total_tokens"] == 2306 + assert result["output_tokens"] == 450 + assert result["total_tokens"] == 2456 assert "input_token_details" in result assert isinstance(result["input_token_details"], dict) assert result["input_token_details"]["cache_read"] == 1920 - assert "output_token_details" not in result + assert "output_token_details" in result + assert isinstance(result["output_token_details"], dict) + assert result["output_token_details"]["reasoning"] == 200 def test_create_usage_metadata_missing_total_tokens() -> None:
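A minimal usage sketch of the fields this series adds, assuming GROQ_API_KEY is set and the Groq API returns the prompt_tokens_details / completion_tokens_details blocks handled by _create_usage_metadata above; the model name and prompt below are only examples, not prescribed by the patch.

    from langchain_groq import ChatGroq

    # Example model name (assumption); any Groq chat model that reports token details works.
    llm = ChatGroq(model="llama-3.3-70b-versatile")
    msg = llm.invoke("Summarize the shared system context, then answer the question.")

    # usage_metadata is populated from the API's token usage by _create_usage_metadata;
    # it can be None if the response carries no usage block.
    usage = msg.usage_metadata or {}
    print(usage.get("input_tokens"), usage.get("output_tokens"), usage.get("total_tokens"))

    # The detail keys are present only when the API reports non-zero values:
    print(usage.get("input_token_details", {}).get("cache_read"))   # cached prompt tokens
    print(usage.get("output_token_details", {}).get("reasoning"))   # reasoning tokens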