Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Changed

- Set `run_inline = True` on the tracer so LangChain callbacks run inline for correct OpenTelemetry context propagation
([#148](https://github.com/alibaba/loongsuite-python-agent/pull/148))
- Improved token usage extraction to support multiple LangChain/LLM provider formats
([#148](https://github.com/alibaba/loongsuite-python-agent/pull/148))

## Version 0.2.0 (2026-03-12)

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(_schema_format="original+chat", **kwargs)
# We need the callbacks to run inline so that the OpenTelemetry context propagates correctly.
self.run_inline = True
self._handler = handler
self._tracer = get_tracer(
__name__,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,15 @@ def _parse_token_usage_dict(token_usage: Any) -> tuple[int | None, int | None]:
"""Parse a token_usage/usage dict into (input_tokens, output_tokens)."""
if not isinstance(token_usage, dict):
return None, None
inp = token_usage.get("prompt_tokens") or token_usage.get("input_tokens")
out = token_usage.get("completion_tokens") or token_usage.get(
"output_tokens"
inp = (
token_usage.get("prompt_tokens")
or token_usage.get("PromptTokens")
or token_usage.get("input_tokens")
)
out = (
token_usage.get("completion_tokens")
or token_usage.get("CompletionTokens")
or token_usage.get("output_tokens")
)
return (
int(inp) if inp is not None else None,
Expand All @@ -334,7 +340,8 @@ def _extract_token_usage(run: Any) -> tuple[int | None, int | None]:
Tries multiple LangChain formats in order:
1. outputs["llm_output"]["token_usage"] or ["usage"]
2. generations[i][j]["generation_info"]["token_usage"] or ["usage"]
3. generations[i][j]["message"].response_metadata or ["kwargs"]["response_metadata"]
3. generations[i][j]["message"].response_metadata (or the serialized form generations[i][j]["message"]["kwargs"]["response_metadata"]), checking the "token_usage" key first and then "usage"
4. generations[i][j]["message"].usage_metadata or generations[i][j]["message"]["kwargs"]["usage_metadata"]
"""
outputs = getattr(run, "outputs", None) or {}

Expand All @@ -348,7 +355,8 @@ def _extract_token_usage(run: Any) -> tuple[int | None, int | None]:
return inp, out

# 2. Fallback: generations[][].generation_info["token_usage"] or ["usage"]
# 3. Fallback: generations[][].message.response_metadata["token_usage"]
# 3. Fallback: generations[][].message.response_metadata["token_usage"] or generations[][].message.response_metadata["usage"] or generations[][].message["kwargs"]["response_metadata"]["token_usage"] or generations[][].message["kwargs"]["response_metadata"]["usage"]
# 4. Fallback: generations[][].message.usage_metadata or generations[][].message["kwargs"]["usage_metadata"]
for gen_list in outputs.get("generations") or []:
if not isinstance(gen_list, list):
continue
Expand All @@ -363,10 +371,10 @@ def _extract_token_usage(run: Any) -> tuple[int | None, int | None]:
inp, out = _parse_token_usage_dict(token_usage)
if inp is not None or out is not None:
return inp, out
# Try message.response_metadata (serialized: kwargs.response_metadata)
msg = gen.get("message")
if msg is None:
continue
# Try message.response_metadata (serialized: kwargs.response_metadata)
if isinstance(msg, dict):
metadata = (msg.get("kwargs") or {}).get(
"response_metadata"
Expand All @@ -380,6 +388,16 @@ def _extract_token_usage(run: Any) -> tuple[int | None, int | None]:
inp, out = _parse_token_usage_dict(token_usage)
if inp is not None or out is not None:
return inp, out
# Try message.usage_metadata (serialized: kwargs.usage_metadata)
if isinstance(msg, dict):
metadata = (msg.get("kwargs") or {}).get(
"usage_metadata"
) or {}
else:
metadata = getattr(msg, "usage_metadata", None) or {}
inp, out = _parse_token_usage_dict(metadata)
if inp is not None or out is not None:
return inp, out

return None, None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_extract_response_model,
_extract_token_usage,
_extract_tool_definitions,
_parse_token_usage_dict,
_safe_json,
)
from opentelemetry.util.genai.types import (
Expand Down Expand Up @@ -298,6 +299,56 @@ def test_empty_outputs(self):
assert _extract_llm_output_messages(run) == []


class TestParseTokenUsageDict:
    """Unit tests for _parse_token_usage_dict."""

    def test_prompt_completion_tokens(self):
        """Standard OpenAI format: prompt_tokens, completion_tokens."""
        usage = {"prompt_tokens": 10, "completion_tokens": 20}
        assert _parse_token_usage_dict(usage) == (10, 20)

    def test_input_output_tokens(self):
        """Anthropic/Claude format: input_tokens, output_tokens."""
        usage = {"input_tokens": 5, "output_tokens": 15}
        assert _parse_token_usage_dict(usage) == (5, 15)

    def test_azure_format_pascal_case(self):
        """Azure format: PromptTokens, CompletionTokens (PascalCase)."""
        usage = {"PromptTokens": 100, "CompletionTokens": 50}
        assert _parse_token_usage_dict(usage) == (100, 50)

    def test_partial_input_only(self):
        """Only input tokens present, output is None."""
        assert _parse_token_usage_dict({"prompt_tokens": 42}) == (42, None)

    def test_partial_output_only(self):
        """Only output tokens present, input is None."""
        assert _parse_token_usage_dict({"completion_tokens": 8}) == (None, 8)

    def test_non_dict_returns_none(self):
        """Non-dict input returns (None, None)."""
        for bad_input in (None, "not a dict", []):
            assert _parse_token_usage_dict(bad_input) == (None, None)

    def test_empty_dict(self):
        """Empty dict returns (None, None)."""
        assert _parse_token_usage_dict({}) == (None, None)


class TestExtractTokenUsage:
def test_from_llm_output(self):
run = _FakeRun(
Expand Down Expand Up @@ -459,6 +510,100 @@ def test_llm_output_takes_precedence(self):
assert inp == 1
assert out == 2

def test_from_message_usage_metadata_dict(self):
"""Token usage may be in message.kwargs.usage_metadata (serialized format)."""
run = _FakeRun(
outputs={
"generations": [
[
{
"text": "Response",
"message": {
"kwargs": {
"content": "Response",
"usage_metadata": {
"input_tokens": 30,
"output_tokens": 12,
},
}
},
}
]
]
}
)
inp, out = _extract_token_usage(run)
assert inp == 30
assert out == 12

def test_from_message_usage_metadata_object(self):
"""Token usage may be in message.usage_metadata (object format)."""

class _FakeMessage:
usage_metadata = {
"prompt_tokens": 80,
"completion_tokens": 20,
}

run = _FakeRun(
outputs={
"generations": [
[
{
"text": "Response",
"message": _FakeMessage(),
}
]
]
}
)
inp, out = _extract_token_usage(run)
assert inp == 80
assert out == 20

def test_llm_output_azure_format(self):
"""llm_output may use Azure-style PromptTokens/CompletionTokens."""
run = _FakeRun(
outputs={
"llm_output": {
"token_usage": {
"PromptTokens": 200,
"CompletionTokens": 75,
}
}
}
)
inp, out = _extract_token_usage(run)
assert inp == 200
assert out == 75

def test_response_metadata_usage_key(self):
"""response_metadata may use 'usage' key instead of 'token_usage'."""
run = _FakeRun(
outputs={
"generations": [
[
{
"text": "Hi",
"message": {
"kwargs": {
"response_metadata": {
"usage": {
"prompt_tokens": 15,
"completion_tokens": 7,
}
},
}
},
}
]
]
}
)
inp, out = _extract_token_usage(run)
assert inp == 15
assert out == 7

def test_no_token_usage(self):
run = _FakeRun(outputs={})
inp, out = _extract_token_usage(run)
Expand Down
Loading