|
37 | 37 | ToolMessage, |
38 | 38 | ToolMessageChunk, |
39 | 39 | ) |
| 40 | +from langchain_core.messages.ai import ( |
| 41 | + InputTokenDetails, |
| 42 | + OutputTokenDetails, |
| 43 | + UsageMetadata, |
| 44 | +) |
40 | 45 | from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser |
41 | 46 | from langchain_core.output_parsers.base import OutputParserLike |
42 | 47 | from langchain_core.output_parsers.openai_tools import ( |
@@ -726,15 +731,7 @@ def _create_chat_result( |
726 | 731 | for res in response["choices"]: |
727 | 732 | message = _convert_dict_to_message(res["message"]) |
728 | 733 | if token_usage and isinstance(message, AIMessage): |
729 | | - input_tokens = token_usage.get("prompt_tokens", 0) |
730 | | - output_tokens = token_usage.get("completion_tokens", 0) |
731 | | - message.usage_metadata = { |
732 | | - "input_tokens": input_tokens, |
733 | | - "output_tokens": output_tokens, |
734 | | - "total_tokens": token_usage.get( |
735 | | - "total_tokens", input_tokens + output_tokens |
736 | | - ), |
737 | | - } |
| 734 | + message.usage_metadata = _create_usage_metadata(token_usage) |
738 | 735 | generation_info = {"finish_reason": res.get("finish_reason")} |
739 | 736 | if "logprobs" in res: |
740 | 737 | generation_info["logprobs"] = res["logprobs"] |
@@ -774,7 +771,20 @@ def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict: |
774 | 771 | if token_usage is not None: |
775 | 772 | for k, v in token_usage.items(): |
776 | 773 | if k in overall_token_usage and v is not None: |
777 | | - overall_token_usage[k] += v |
| 774 | + # Handle nested dictionaries |
| 775 | + if isinstance(v, dict): |
| 776 | + if k not in overall_token_usage: |
| 777 | + overall_token_usage[k] = {} |
| 778 | + for nested_k, nested_v in v.items(): |
| 779 | + if ( |
| 780 | + nested_k in overall_token_usage[k] |
| 781 | + and nested_v is not None |
| 782 | + ): |
| 783 | + overall_token_usage[k][nested_k] += nested_v |
| 784 | + else: |
| 785 | + overall_token_usage[k][nested_k] = nested_v |
| 786 | + else: |
| 787 | + overall_token_usage[k] += v |
778 | 788 | else: |
779 | 789 | overall_token_usage[k] = v |
780 | 790 | if system_fingerprint is None: |
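
The nested-dictionary branch added here sums per-category token detail counts (e.g. completion_tokens_details) across batched outputs instead of only summing flat integers. A rough standalone sketch of that merge, using invented usage dicts shaped like Groq's Chat Completions payload rather than the actual method:

overall = {"total_tokens": 150, "completion_tokens_details": {"reasoning_tokens": 20}}
incoming = {"total_tokens": 90, "completion_tokens_details": {"reasoning_tokens": 5}}

for k, v in incoming.items():
    if k in overall and v is not None:
        if isinstance(v, dict):
            # Sum nested counters key by key; copy keys not seen before.
            for nested_k, nested_v in v.items():
                if nested_k in overall[k] and nested_v is not None:
                    overall[k][nested_k] += nested_v
                else:
                    overall[k][nested_k] = nested_v
        else:
            overall[k] += v
    else:
        overall[k] = v

# overall == {"total_tokens": 240, "completion_tokens_details": {"reasoning_tokens": 25}}
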
@@ -1329,13 +1339,7 @@ def _convert_chunk_to_message_chunk( |
1329 | 1339 | {k: executed_tool[k] for k in executed_tool if k != "output"} |
1330 | 1340 | ) |
1331 | 1341 | if usage := (chunk.get("x_groq") or {}).get("usage"): |
1332 | | - input_tokens = usage.get("prompt_tokens", 0) |
1333 | | - output_tokens = usage.get("completion_tokens", 0) |
1334 | | - usage_metadata = { |
1335 | | - "input_tokens": input_tokens, |
1336 | | - "output_tokens": output_tokens, |
1337 | | - "total_tokens": usage.get("total_tokens", input_tokens + output_tokens), |
1338 | | - } |
| 1342 | + usage_metadata = _create_usage_metadata(usage) |
1339 | 1343 | else: |
1340 | 1344 | usage_metadata = None |
1341 | 1345 | return AIMessageChunk( |
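
In the streaming path, usage arrives on the final chunk under an x_groq key, and the inline dict construction is replaced by the new helper. A minimal sketch of that extraction, assuming the _create_usage_metadata helper added later in this diff is in scope and using an invented chunk payload:

chunk = {
    "x_groq": {
        "usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20}
    }
}
if usage := (chunk.get("x_groq") or {}).get("usage"):
    usage_metadata = _create_usage_metadata(usage)
else:
    usage_metadata = None
# usage_metadata == {"input_tokens": 12, "output_tokens": 8, "total_tokens": 20}
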
@@ -1435,3 +1439,59 @@ def _lc_invalid_tool_call_to_groq_tool_call( |
1435 | 1439 | "arguments": invalid_tool_call["args"], |
1436 | 1440 | }, |
1437 | 1441 | } |
| 1442 | + |
| 1443 | + |
| 1444 | +def _create_usage_metadata(groq_token_usage: dict) -> UsageMetadata: |
| 1445 | + """Create usage metadata from Groq token usage response. |
| 1446 | + |
| 1447 | + Args: |
| 1448 | + groq_token_usage: Token usage dict from Groq API response. |
| 1449 | + |
| 1450 | + Returns: |
| 1451 | + Usage metadata dict with input/output token details. |
| 1452 | + """ |
| 1453 | + # Support both formats: new Responses API uses "input_tokens", |
| 1454 | + # Chat Completions API uses "prompt_tokens" |
| 1455 | + input_tokens = ( |
| 1456 | + groq_token_usage.get("input_tokens") |
| 1457 | + or groq_token_usage.get("prompt_tokens") |
| 1458 | + or 0 |
| 1459 | + ) |
| 1460 | + output_tokens = ( |
| 1461 | + groq_token_usage.get("output_tokens") |
| 1462 | + or groq_token_usage.get("completion_tokens") |
| 1463 | + or 0 |
| 1464 | + ) |
| 1465 | + total_tokens = groq_token_usage.get("total_tokens") or input_tokens + output_tokens |
| 1466 | + |
| 1467 | + # Support both formats for token details: |
| 1468 | + # Responses API uses "*_tokens_details", Chat Completions API might use |
| 1469 | + # "prompt_tokens_details" |
| 1470 | + input_details_dict = ( |
| 1471 | + groq_token_usage.get("input_tokens_details") |
| 1472 | + or groq_token_usage.get("prompt_tokens_details") |
| 1473 | + or {} |
| 1474 | + ) |
| 1475 | + output_details_dict = ( |
| 1476 | + groq_token_usage.get("output_tokens_details") |
| 1477 | + or groq_token_usage.get("completion_tokens_details") |
| 1478 | + or {} |
| 1479 | + ) |
| 1480 | + |
| 1481 | + input_token_details: dict = { |
| 1482 | + "cache_read": input_details_dict.get("cached_tokens"), |
| 1483 | + } |
| 1484 | + output_token_details: dict = { |
| 1485 | + "reasoning": output_details_dict.get("reasoning_tokens"), |
| 1486 | + } |
| 1487 | + usage_metadata: UsageMetadata = { |
| 1488 | + "input_tokens": input_tokens, |
| 1489 | + "output_tokens": output_tokens, |
| 1490 | + "total_tokens": total_tokens, |
| 1491 | + } |
| 1492 | + |
| 1493 | + if filtered_input := {k: v for k, v in input_token_details.items() if v}: |
| 1494 | + usage_metadata["input_token_details"] = InputTokenDetails(**filtered_input) # type: ignore[typeddict-item] |
| 1495 | + if filtered_output := {k: v for k, v in output_token_details.items() if v}: |
| 1496 | + usage_metadata["output_token_details"] = OutputTokenDetails(**filtered_output) # type: ignore[typeddict-item] |
| 1497 | + return usage_metadata |
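
As a quick illustration of what the helper returns for a Chat Completions-style payload (values invented for the example; assumes _create_usage_metadata from the hunk above is in scope):

usage = {
    "prompt_tokens": 100,
    "completion_tokens": 50,
    "total_tokens": 150,
    "prompt_tokens_details": {"cached_tokens": 30},
    "completion_tokens_details": {"reasoning_tokens": 20},
}
metadata = _create_usage_metadata(usage)
# metadata == {
#     "input_tokens": 100,
#     "output_tokens": 50,
#     "total_tokens": 150,
#     "input_token_details": {"cache_read": 30},
#     "output_token_details": {"reasoning": 20},
# }
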