50 changes: 34 additions & 16 deletions libs/partners/groq/langchain_groq/chat_models.py
@@ -37,6 +37,10 @@
ToolMessage,
ToolMessageChunk,
)
from langchain_core.messages.ai import (
InputTokenDetails,
UsageMetadata,
)
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
@@ -704,15 +708,7 @@ def _create_chat_result(
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
if token_usage and isinstance(message, AIMessage):
input_tokens = token_usage.get("prompt_tokens", 0)
output_tokens = token_usage.get("completion_tokens", 0)
message.usage_metadata = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": token_usage.get(
"total_tokens", input_tokens + output_tokens
),
}
message.usage_metadata = _create_usage_metadata(token_usage)
generation_info = {"finish_reason": res.get("finish_reason")}
if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"]
@@ -1303,13 +1299,7 @@ def _convert_chunk_to_message_chunk(
{k: executed_tool[k] for k in executed_tool if k != "output"}
)
if usage := (chunk.get("x_groq") or {}).get("usage"):
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
usage_metadata = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": usage.get("total_tokens", input_tokens + output_tokens),
}
usage_metadata = _create_usage_metadata(usage)
else:
usage_metadata = None
return AIMessageChunk(
@@ -1409,3 +1399,31 @@ def _lc_invalid_tool_call_to_groq_tool_call(
"arguments": invalid_tool_call["args"],
},
}


def _create_usage_metadata(groq_token_usage: dict) -> UsageMetadata:
"""Create usage metadata from Groq token usage response.

Args:
groq_token_usage: Token usage dict from Groq API response.

Returns:
Usage metadata dict with input/output token details.
"""
input_tokens = groq_token_usage.get("prompt_tokens") or 0
output_tokens = groq_token_usage.get("completion_tokens") or 0
total_tokens = groq_token_usage.get("total_tokens") or input_tokens + output_tokens
input_token_details: dict = {
"cache_read": (groq_token_usage.get("prompt_tokens_details") or {}).get(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't find anywhere in Groq's docs where prompt_token_details are returned?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be input_tokens_details

Copy link
Author

@MshariAlaeena MshariAlaeena Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for checking, It’s actually documented here:
https://console.groq.com/docs/prompt-caching#tracking-cache-usage

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MshariAlaeena I think that might be wrong? See Slack

"cached_tokens"
),
}
usage_metadata: UsageMetadata = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": total_tokens,
}

if filtered_input := {k: v for k, v in input_token_details.items() if v}:
usage_metadata["input_token_details"] = InputTokenDetails(**filtered_input) # type: ignore[typeddict-item]
return usage_metadata
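
For reference, a minimal sketch of what the new helper produces, using the usage payload shape from this PR's own test fixtures and the prompt-caching doc linked in the thread above (whether Groq actually nests the cached-token count under "prompt_tokens_details" is still under discussion):

from langchain_groq.chat_models import _create_usage_metadata

# Usage payload shape taken from the unit tests below.
groq_usage = {
    "prompt_tokens": 2006,
    "completion_tokens": 300,
    "total_tokens": 2306,
    "prompt_tokens_details": {"cached_tokens": 1920},
}

usage = _create_usage_metadata(groq_usage)
assert usage == {
    "input_tokens": 2006,
    "output_tokens": 300,
    "total_tokens": 2306,
    "input_token_details": {"cache_read": 1920},
}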
276 changes: 275 additions & 1 deletion libs/partners/groq/tests/unit_tests/test_chat_models.py
@@ -9,14 +9,20 @@
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
FunctionMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
)

from langchain_groq.chat_models import ChatGroq, _convert_dict_to_message
from langchain_groq.chat_models import (
ChatGroq,
_convert_chunk_to_message_chunk,
_convert_dict_to_message,
_create_usage_metadata,
)

if "GROQ_API_KEY" not in os.environ:
os.environ["GROQ_API_KEY"] = "fake-key"
@@ -283,3 +289,271 @@ def test_groq_serialization() -> None:

# Ensure a None was preserved
assert llm.groq_api_base == llm2.groq_api_base


def test_create_usage_metadata_basic() -> None:
"""Test basic usage metadata creation without details."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
}

result = _create_usage_metadata(token_usage)

assert isinstance(result, dict)
assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150
assert "input_token_details" not in result
assert "output_token_details" not in result


def test_create_usage_metadata_with_cached_tokens() -> None:
"""Test usage metadata with prompt caching."""
token_usage = {
"prompt_tokens": 2006,
"completion_tokens": 300,
"total_tokens": 2306,
"prompt_tokens_details": {"cached_tokens": 1920},
}

result = _create_usage_metadata(token_usage)

assert isinstance(result, dict)
assert result["input_tokens"] == 2006
assert result["output_tokens"] == 300
assert result["total_tokens"] == 2306
assert "input_token_details" in result
assert isinstance(result["input_token_details"], dict)
assert result["input_token_details"]["cache_read"] == 1920
assert "output_token_details" not in result


def test_create_usage_metadata_with_all_details() -> None:
"""Test usage metadata with all available details."""
token_usage = {
"prompt_tokens": 2006,
"completion_tokens": 300,
"total_tokens": 2306,
"prompt_tokens_details": {"cached_tokens": 1920},
}

result = _create_usage_metadata(token_usage)

assert isinstance(result, dict)
assert result["input_tokens"] == 2006
assert result["output_tokens"] == 300
assert result["total_tokens"] == 2306

assert "input_token_details" in result
assert isinstance(result["input_token_details"], dict)
assert result["input_token_details"]["cache_read"] == 1920

assert "output_token_details" not in result

Collaborator (on test_create_usage_metadata_with_all_details):
This appears identical to test_create_usage_metadata_with_cached_tokens?

Author:
Good catch! I forgot to add the reasoning fields. Now fixed to test both cached and reasoning tokens together.


def test_create_usage_metadata_missing_total_tokens() -> None:
"""Test that total_tokens is calculated when missing."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
}

result = _create_usage_metadata(token_usage)

assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150


def test_create_usage_metadata_empty_details() -> None:
"""Test that empty detail dicts don't create token detail objects."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
"prompt_tokens_details": {},
}

result = _create_usage_metadata(token_usage)

assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150
assert "input_token_details" not in result
assert "output_token_details" not in result


def test_create_usage_metadata_zero_cached_tokens() -> None:
"""Test that zero cached tokens are not included (falsy)."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
"prompt_tokens_details": {"cached_tokens": 0},
}

result = _create_usage_metadata(token_usage)

assert result["input_tokens"] == 100
assert result["output_tokens"] == 50
assert result["total_tokens"] == 150
assert "input_token_details" not in result


def test_chat_result_with_usage_metadata() -> None:
"""Test that _create_chat_result properly includes usage metadata."""
llm = ChatGroq(model="test-model")

mock_response = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Test response",
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 2006,
"completion_tokens": 300,
"total_tokens": 2306,
"prompt_tokens_details": {"cached_tokens": 1920},
},
}

result = llm._create_chat_result(mock_response, {})

assert len(result.generations) == 1
message = result.generations[0].message
assert isinstance(message, AIMessage)
assert message.content == "Test response"

assert message.usage_metadata is not None
assert isinstance(message.usage_metadata, dict)
assert message.usage_metadata["input_tokens"] == 2006
assert message.usage_metadata["output_tokens"] == 300
assert message.usage_metadata["total_tokens"] == 2306

assert "input_token_details" in message.usage_metadata
assert message.usage_metadata["input_token_details"]["cache_read"] == 1920

assert "output_token_details" not in message.usage_metadata


def test_chat_result_backward_compatibility() -> None:
"""Test that responses without new fields still work."""
llm = ChatGroq(model="test-model")

mock_response = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Test response",
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
},
}

result = llm._create_chat_result(mock_response, {})

assert len(result.generations) == 1
message = result.generations[0].message
assert isinstance(message, AIMessage)

assert message.usage_metadata is not None
assert message.usage_metadata["input_tokens"] == 100
assert message.usage_metadata["output_tokens"] == 50
assert message.usage_metadata["total_tokens"] == 150

assert "input_token_details" not in message.usage_metadata
assert "output_token_details" not in message.usage_metadata


def test_streaming_with_usage_metadata() -> None:
"""Test that streaming properly includes usage metadata."""
chunk = {
"id": "chatcmpl-123",
"object": "chat.completion.chunk",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": "Hello",
},
"finish_reason": None,
}
],
"x_groq": {
"usage": {
"prompt_tokens": 2006,
"completion_tokens": 300,
"total_tokens": 2306,
"prompt_tokens_details": {"cached_tokens": 1920},
}
},
}

result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)

assert isinstance(result, AIMessageChunk)
assert result.content == "Hello"

assert result.usage_metadata is not None
assert isinstance(result.usage_metadata, dict)
assert result.usage_metadata["input_tokens"] == 2006
assert result.usage_metadata["output_tokens"] == 300
assert result.usage_metadata["total_tokens"] == 2306

assert "input_token_details" in result.usage_metadata
assert result.usage_metadata["input_token_details"]["cache_read"] == 1920

assert "output_token_details" not in result.usage_metadata


def test_streaming_without_usage_metadata() -> None:
"""Test that streaming works without usage metadata (backward compatibility)."""
chunk = {
"id": "chatcmpl-123",
"object": "chat.completion.chunk",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": "Hello",
},
"finish_reason": None,
}
],
}

result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk)

assert isinstance(result, AIMessageChunk)
assert result.content == "Hello"
assert result.usage_metadata is None
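

A minimal sketch of how a caller would read the new cache_read detail once this change lands; the model name below is purely illustrative, and any Groq model behaves the same way:

from langchain_groq import ChatGroq

llm = ChatGroq(model="llama-3.3-70b-versatile")  # illustrative model name
msg = llm.invoke("Hello")

usage = msg.usage_metadata
if usage:
    print(usage["input_tokens"], usage["output_tokens"], usage["total_tokens"])
    # Only present when Groq reports cached prompt tokens for the request.
    print(usage.get("input_token_details", {}).get("cache_read", 0))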