73 changes: 56 additions & 17 deletions libs/partners/groq/langchain_groq/chat_models.py
@@ -37,6 +37,11 @@
     ToolMessage,
     ToolMessageChunk,
 )
+from langchain_core.messages.ai import (
+    InputTokenDetails,
+    OutputTokenDetails,
+    UsageMetadata,
+)
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.openai_tools import (
@@ -704,15 +709,7 @@ def _create_chat_result(
         for res in response["choices"]:
             message = _convert_dict_to_message(res["message"])
             if token_usage and isinstance(message, AIMessage):
-                input_tokens = token_usage.get("prompt_tokens", 0)
-                output_tokens = token_usage.get("completion_tokens", 0)
-                message.usage_metadata = {
-                    "input_tokens": input_tokens,
-                    "output_tokens": output_tokens,
-                    "total_tokens": token_usage.get(
-                        "total_tokens", input_tokens + output_tokens
-                    ),
-                }
+                message.usage_metadata = _create_usage_metadata(token_usage)
             generation_info = {"finish_reason": res.get("finish_reason")}
             if "logprobs" in res:
                 generation_info["logprobs"] = res["logprobs"]
@@ -752,7 +749,20 @@ def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict:
             if token_usage is not None:
                 for k, v in token_usage.items():
                     if k in overall_token_usage and v is not None:
-                        overall_token_usage[k] += v
+                        # Handle nested dictionaries
+                        if isinstance(v, dict):
+                            if k not in overall_token_usage:
+                                overall_token_usage[k] = {}
+                            for nested_k, nested_v in v.items():
+                                if (
+                                    nested_k in overall_token_usage[k]
+                                    and nested_v is not None
+                                ):
+                                    overall_token_usage[k][nested_k] += nested_v
+                                else:
+                                    overall_token_usage[k][nested_k] = nested_v
+                        else:
+                            overall_token_usage[k] += v
                     else:
                         overall_token_usage[k] = v
         if system_fingerprint is None:
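
To illustrate the merge behavior this hunk adds, a minimal sketch (the payloads are hypothetical): when per-call token_usage dicts carry nested detail dicts, the nested counters are now summed key by key, where the old "overall_token_usage[k] += v" would have raised a TypeError for dict values.

    usage_a = {"completion_tokens": 10, "completion_tokens_details": {"reasoning_tokens": 4}}
    usage_b = {"completion_tokens": 7, "completion_tokens_details": {"reasoning_tokens": 2}}
    # Merging outputs carrying usage_a and usage_b yields:
    # {"completion_tokens": 17, "completion_tokens_details": {"reasoning_tokens": 6}}
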
@@ -1303,13 +1313,7 @@ def _convert_chunk_to_message_chunk(
             {k: executed_tool[k] for k in executed_tool if k != "output"}
         )
     if usage := (chunk.get("x_groq") or {}).get("usage"):
-        input_tokens = usage.get("prompt_tokens", 0)
-        output_tokens = usage.get("completion_tokens", 0)
-        usage_metadata = {
-            "input_tokens": input_tokens,
-            "output_tokens": output_tokens,
-            "total_tokens": usage.get("total_tokens", input_tokens + output_tokens),
-        }
+        usage_metadata = _create_usage_metadata(usage)
     else:
         usage_metadata = None
     return AIMessageChunk(
@@ -1409,3 +1413,38 @@ def _lc_invalid_tool_call_to_groq_tool_call(
             "arguments": invalid_tool_call["args"],
         },
     }
+
+
+def _create_usage_metadata(groq_token_usage: dict) -> UsageMetadata:
+    """Create usage metadata from Groq token usage response.
+
+    Args:
+        groq_token_usage: Token usage dict from Groq API response.
+
+    Returns:
+        Usage metadata dict with input/output token details.
+    """
+    input_tokens = groq_token_usage.get("prompt_tokens") or 0
+    output_tokens = groq_token_usage.get("completion_tokens") or 0
+    total_tokens = groq_token_usage.get("total_tokens") or input_tokens + output_tokens
+    input_token_details: dict = {
+        "cache_read": (groq_token_usage.get("prompt_tokens_details") or {}).get(
+            "cached_tokens"
+        ),
+    }

[Inline review thread on the prompt_tokens_details lookup above]

Collaborator: I can't find anywhere in Groq's docs where prompt_tokens_details is returned?

Collaborator: I think it should be input_tokens_details.

@MshariAlaeena (Contributor, Nov 6, 2025): Thanks for checking. It's actually documented here:
https://console.groq.com/docs/prompt-caching#tracking-cache-usage

Collaborator: @MshariAlaeena I think that might be wrong? See Slack.

+    output_token_details: dict = {
+        "reasoning": (groq_token_usage.get("completion_tokens_details") or {}).get(
+            "reasoning_tokens"
+        ),
+    }
+    usage_metadata: UsageMetadata = {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "total_tokens": total_tokens,
+    }
+
+    if filtered_input := {k: v for k, v in input_token_details.items() if v}:
+        usage_metadata["input_token_details"] = InputTokenDetails(**filtered_input)  # type: ignore[typeddict-item]
+    if filtered_output := {k: v for k, v in output_token_details.items() if v}:
+        usage_metadata["output_token_details"] = OutputTokenDetails(**filtered_output)  # type: ignore[typeddict-item]
+    return usage_metadata
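
As a usage illustration of the new helper, a sketch with a hypothetical payload (the field names mirror what the code reads; actual Groq responses may differ, as the review thread above notes):

    usage = {
        "prompt_tokens": 100,
        "completion_tokens": 50,
        "total_tokens": 150,
        "prompt_tokens_details": {"cached_tokens": 80},
        "completion_tokens_details": {"reasoning_tokens": 12},
    }
    _create_usage_metadata(usage)
    # -> {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150,
    #     "input_token_details": {"cache_read": 80},
    #     "output_token_details": {"reasoning": 12}}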