
Commit f864566

Updated recording of token usage
1 parent a7b2d67 commit f864566

File tree: 7 files changed (+169, -76 lines)

sentry_sdk/ai/monitoring.py

Lines changed: 29 additions & 13 deletions
@@ -96,21 +96,37 @@ async def async_wrapped(*args, **kwargs):


 def record_token_usage(
-    span, prompt_tokens=None, completion_tokens=None, total_tokens=None
+    span,
+    input_tokens=None,
+    input_tokens_cached=None,
+    output_tokens=None,
+    output_tokens_reasoning=None,
+    total_tokens=None,
 ):
-    # type: (Span, Optional[int], Optional[int], Optional[int]) -> None
+    # type: (Span, Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]) -> None
+
+    # TODO: move pipeline name elsewhere
     ai_pipeline_name = get_ai_pipeline_name()
     if ai_pipeline_name:
         span.set_data(SPANDATA.AI_PIPELINE_NAME, ai_pipeline_name)
-    if prompt_tokens is not None:
-        span.set_measurement("ai_prompt_tokens_used", value=prompt_tokens)
-    if completion_tokens is not None:
-        span.set_measurement("ai_completion_tokens_used", value=completion_tokens)
-    if (
-        total_tokens is None
-        and prompt_tokens is not None
-        and completion_tokens is not None
-    ):
-        total_tokens = prompt_tokens + completion_tokens
+
+    if input_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
+
+    if input_tokens_cached is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
+            input_tokens_cached,
+        )
+
+    if output_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
+
+    if output_tokens_reasoning is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
+            output_tokens_reasoning,
+        )
+
     if total_tokens is not None:
-        span.set_measurement("ai_total_tokens_used", total_tokens)
+        span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total_tokens)
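Note what changed here: token counts now land via span.set_data(...) with SPANDATA.GEN_AI_USAGE_* keys instead of span.set_measurement(...), and the old fallback that derived total_tokens from prompt + completion was dropped, so callers pass the total explicitly. For reference, a minimal sketch of a call site in the new keyword style; the counts are invented and span stands for any active Sentry span:

record_token_usage(
    span,
    input_tokens=120,            # becomes SPANDATA.GEN_AI_USAGE_INPUT_TOKENS
    input_tokens_cached=40,      # cached subset of the input tokens
    output_tokens=256,           # becomes SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS
    output_tokens_reasoning=64,  # reasoning subset of the output tokens
    total_tokens=376,            # 120 input + 256 output, supplied by the caller
)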

sentry_sdk/integrations/anthropic.py

Lines changed: 13 additions & 2 deletions
@@ -65,7 +65,13 @@ def _calculate_token_usage(result, span):
             output_tokens = usage.output_tokens

     total_tokens = input_tokens + output_tokens
-    record_token_usage(span, input_tokens, output_tokens, total_tokens)
+
+    record_token_usage(
+        span,
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        total_tokens=total_tokens,
+    )


 def _get_responses(content):

@@ -126,7 +132,12 @@ def _add_ai_data_to_span(
                 [{"type": "text", "text": complete_message}],
             )
         total_tokens = input_tokens + output_tokens
-        record_token_usage(span, input_tokens, output_tokens, total_tokens)
+        record_token_usage(
+            span,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            total_tokens=total_tokens,
+        )
         span.set_data(SPANDATA.AI_STREAMING, True)

sentry_sdk/integrations/cohere.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,14 @@ def collect_chat_response_fields(span, res, include_pii):
116116
if hasattr(res.meta, "billed_units"):
117117
record_token_usage(
118118
span,
119-
prompt_tokens=res.meta.billed_units.input_tokens,
120-
completion_tokens=res.meta.billed_units.output_tokens,
119+
input_tokens=res.meta.billed_units.input_tokens,
120+
output_tokens=res.meta.billed_units.output_tokens,
121121
)
122122
elif hasattr(res.meta, "tokens"):
123123
record_token_usage(
124124
span,
125-
prompt_tokens=res.meta.tokens.input_tokens,
126-
completion_tokens=res.meta.tokens.output_tokens,
125+
input_tokens=res.meta.tokens.input_tokens,
126+
output_tokens=res.meta.tokens.output_tokens,
127127
)
128128

129129
if hasattr(res.meta, "warnings"):
@@ -262,7 +262,7 @@ def new_embed(*args, **kwargs):
262262
):
263263
record_token_usage(
264264
span,
265-
prompt_tokens=res.meta.billed_units.input_tokens,
265+
input_tokens=res.meta.billed_units.input_tokens,
266266
total_tokens=res.meta.billed_units.input_tokens,
267267
)
268268
return res
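In the embed path above there is no completion, so the billed input count doubles as the total. A tiny sketch with a stubbed response object; the SimpleNamespace stand-in and the count are invented for illustration:

from types import SimpleNamespace

# Invented stand-in for a Cohere embed response carrying billed_units metadata.
res = SimpleNamespace(
    meta=SimpleNamespace(billed_units=SimpleNamespace(input_tokens=7))
)

record_token_usage(
    span,
    input_tokens=res.meta.billed_units.input_tokens,  # 7
    total_tokens=res.meta.billed_units.input_tokens,  # embeddings produce no output tokens
)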

sentry_sdk/integrations/huggingface_hub.py

Lines changed: 8 additions & 2 deletions
@@ -111,7 +111,10 @@ def new_text_generation(*args, **kwargs):
                 [res.generated_text],
             )
         if res.details is not None and res.details.generated_tokens > 0:
-            record_token_usage(span, total_tokens=res.details.generated_tokens)
+            record_token_usage(
+                span,
+                total_tokens=res.details.generated_tokens,
+            )
         span.__exit__(None, None, None)
         return res

@@ -145,7 +148,10 @@ def new_details_iterator():
                     span, SPANDATA.AI_RESPONSES, "".join(data_buf)
                 )
             if tokens_used > 0:
-                record_token_usage(span, total_tokens=tokens_used)
+                record_token_usage(
+                    span,
+                    total_tokens=tokens_used,
+                )
             span.__exit__(None, None, None)

         return new_details_iterator()

sentry_sdk/integrations/langchain.py

Lines changed: 5 additions & 5 deletions
@@ -278,15 +278,15 @@ def on_llm_end(self, response, *, run_id, **kwargs):
             if token_usage:
                 record_token_usage(
                     span_data.span,
-                    token_usage.get("prompt_tokens"),
-                    token_usage.get("completion_tokens"),
-                    token_usage.get("total_tokens"),
+                    input_tokens=token_usage.get("prompt_tokens"),
+                    output_tokens=token_usage.get("completion_tokens"),
+                    total_tokens=token_usage.get("total_tokens"),
                 )
             else:
                 record_token_usage(
                     span_data.span,
-                    span_data.num_prompt_tokens,
-                    span_data.num_completion_tokens,
+                    input_tokens=span_data.num_prompt_tokens,
+                    output_tokens=span_data.num_completion_tokens,
                 )

         self._exit_span(span_data, run_id)

sentry_sdk/integrations/openai.py

Lines changed: 58 additions & 33 deletions
@@ -70,48 +70,75 @@ def _capture_exception(exc):
     sentry_sdk.capture_event(event, hint=hint)


-def _calculate_chat_completion_usage(
+def _get_usage(usage, names):
+    # type: (Any, List[str]) -> int
+    for name in names:
+        if hasattr(usage, name) and isinstance(getattr(usage, name), int):
+            return getattr(usage, name)
+    return 0
+
+
+def _calculate_token_usage(
     messages, response, span, streaming_message_responses, count_tokens
 ):
     # type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]], Callable[..., Any]) -> None
-    completion_tokens = 0  # type: Optional[int]
-    prompt_tokens = 0  # type: Optional[int]
+    input_tokens = 0  # type: Optional[int]
+    input_tokens_cached = 0  # type: Optional[int]
+    output_tokens = 0  # type: Optional[int]
+    output_tokens_reasoning = 0  # type: Optional[int]
     total_tokens = 0  # type: Optional[int]
+
     if hasattr(response, "usage"):
-        if hasattr(response.usage, "completion_tokens") and isinstance(
-            response.usage.completion_tokens, int
-        ):
-            completion_tokens = response.usage.completion_tokens
-        if hasattr(response.usage, "prompt_tokens") and isinstance(
-            response.usage.prompt_tokens, int
-        ):
-            prompt_tokens = response.usage.prompt_tokens
-        if hasattr(response.usage, "total_tokens") and isinstance(
-            response.usage.total_tokens, int
-        ):
-            total_tokens = response.usage.total_tokens
+        input_tokens = _get_usage(response.usage, ["input_tokens", "prompt_tokens"])
+        if hasattr(response.usage, "input_tokens_details"):
+            input_tokens_cached = _get_usage(
+                response.usage.input_tokens_details, ["cached_tokens"]
+            )

-    if prompt_tokens == 0:
+        output_tokens = _get_usage(
+            response.usage, ["output_tokens", "completion_tokens"]
+        )
+        if hasattr(response.usage, "output_tokens_details"):
+            output_tokens_reasoning = _get_usage(
+                response.usage.output_tokens_details, ["reasoning_tokens"]
+            )
+
+        total_tokens = _get_usage(response.usage, ["total_tokens"])
+
+    # Manually count tokens
+    # TODO: check for responses API
+    if input_tokens == 0:
         for message in messages:
             if "content" in message:
-                prompt_tokens += count_tokens(message["content"])
+                input_tokens += count_tokens(message["content"])

-    if completion_tokens == 0:
+    # TODO: check for responses API
+    if output_tokens == 0:
         if streaming_message_responses is not None:
             for message in streaming_message_responses:
-                completion_tokens += count_tokens(message)
+                output_tokens += count_tokens(message)
         elif hasattr(response, "choices"):
             for choice in response.choices:
                 if hasattr(choice, "message"):
-                    completion_tokens += count_tokens(choice.message)
-
-    if prompt_tokens == 0:
-        prompt_tokens = None
-    if completion_tokens == 0:
-        completion_tokens = None
-    if total_tokens == 0:
-        total_tokens = None
-    record_token_usage(span, prompt_tokens, completion_tokens, total_tokens)
+                    output_tokens += count_tokens(choice.message)
+
+    # Do not set token data if it is 0
+    input_tokens = None if input_tokens == 0 else input_tokens
+    input_tokens_cached = None if input_tokens_cached == 0 else input_tokens_cached
+    output_tokens = None if output_tokens == 0 else output_tokens
+    output_tokens_reasoning = (
+        None if output_tokens_reasoning == 0 else output_tokens_reasoning
+    )
+    total_tokens = None if total_tokens == 0 else total_tokens
+
+    record_token_usage(
+        span,
+        input_tokens=input_tokens,
+        input_tokens_cached=input_tokens_cached,
+        output_tokens=output_tokens,
+        output_tokens_reasoning=output_tokens_reasoning,
+        total_tokens=total_tokens,
+    )


 def _new_chat_completion_common(f, *args, **kwargs):

@@ -158,9 +185,7 @@ def _new_chat_completion_common(f, *args, **kwargs):
                 SPANDATA.AI_RESPONSES,
                 list(map(lambda x: x.message, res.choices)),
             )
-            _calculate_chat_completion_usage(
-                messages, res, span, None, integration.count_tokens
-            )
+            _calculate_token_usage(messages, res, span, None, integration.count_tokens)
             span.__exit__(None, None, None)
         elif hasattr(res, "_iterator"):
             data_buf: list[list[str]] = []  # one for each choice

@@ -191,7 +216,7 @@ def new_iterator():
                         set_data_normalized(
                             span, SPANDATA.AI_RESPONSES, all_responses
                         )
-                    _calculate_chat_completion_usage(
+                    _calculate_token_usage(
                         messages,
                         res,
                         span,

@@ -224,7 +249,7 @@ async def new_iterator_async():
                         set_data_normalized(
                             span, SPANDATA.AI_RESPONSES, all_responses
                         )
-                    _calculate_chat_completion_usage(
+                    _calculate_token_usage(
                         messages,
                         res,
                         span,
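The new _get_usage helper is what lets one code path read both usage shapes: it returns the first attribute from names that exists and is an int, else 0, which the caller later normalizes to None. A self-contained sketch, with the helper copied from the hunk above and SimpleNamespace stubs invented to stand in for real OpenAI usage objects:

from types import SimpleNamespace


def _get_usage(usage, names):
    # Copied from the diff above: the first matching int attribute wins.
    for name in names:
        if hasattr(usage, name) and isinstance(getattr(usage, name), int):
            return getattr(usage, name)
    return 0


# Chat Completions shape: only the legacy field names exist.
chat_usage = SimpleNamespace(prompt_tokens=11, completion_tokens=42)
assert _get_usage(chat_usage, ["input_tokens", "prompt_tokens"]) == 11

# Newer (presumably Responses-style) shape: input_tokens is listed first, so it wins.
responses_usage = SimpleNamespace(input_tokens=9, output_tokens=3)
assert _get_usage(responses_usage, ["input_tokens", "prompt_tokens"]) == 9

# No match at all: 0 comes back, which _calculate_token_usage turns into None.
assert _get_usage(SimpleNamespace(), ["total_tokens"]) == 0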
