
Commit fc2a794

Merge branch 'antonpirker/openai-overhaul' into antonpirker/openai-responses-api

2 parents: 281005f + 2ccab61

File tree: 7 files changed, +165 -73 lines


sentry_sdk/ai/monitoring.py

Lines changed: 27 additions & 12 deletions
@@ -96,25 +96,40 @@ async def async_wrapped(*args, **kwargs):
 
 
 def record_token_usage(
-    span, prompt_tokens=None, completion_tokens=None, total_tokens=None
+    span,
+    input_tokens=None,
+    input_tokens_cached=None,
+    output_tokens=None,
+    output_tokens_reasoning=None,
+    total_tokens=None,
 ):
-    # type: (Span, Optional[int], Optional[int], Optional[int]) -> None
+    # type: (Span, Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]) -> None
+
+    # TODO: move pipeline name elsewhere
     ai_pipeline_name = get_ai_pipeline_name()
     if ai_pipeline_name:
         span.set_data(SPANDATA.AI_PIPELINE_NAME, ai_pipeline_name)
 
-    if prompt_tokens is not None:
-        span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, prompt_tokens)
+    if input_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
+
+    if input_tokens_cached is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
+            input_tokens_cached,
+        )
+
+    if output_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
 
-    if completion_tokens is not None:
-        span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, completion_tokens)
+    if output_tokens_reasoning is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
+            output_tokens_reasoning,
+        )
 
-    if (
-        total_tokens is None
-        and prompt_tokens is not None
-        and completion_tokens is not None
-    ):
-        total_tokens = prompt_tokens + completion_tokens
+    if total_tokens is None and input_tokens is not None and output_tokens is not None:
+        total_tokens = input_tokens + output_tokens
 
     if total_tokens is not None:
         span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total_tokens)
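Note: record_token_usage is now meant to be called with keywords, and derives total_tokens when it is not supplied. A minimal sketch of the new call shape (FakeSpan is a hypothetical stand-in for the SDK's Span, used only to show which SPANDATA keys get set; the keyword names and keys come from the diff above):

from sentry_sdk.ai.monitoring import record_token_usage

class FakeSpan:
    # Hypothetical stand-in: records set_data calls like a real Span would.
    def __init__(self):
        self.data = {}

    def set_data(self, key, value):
        self.data[key] = value

span = FakeSpan()
record_token_usage(
    span,
    input_tokens=100,
    input_tokens_cached=20,
    output_tokens=50,
    output_tokens_reasoning=10,
    # total_tokens omitted: derived as input_tokens + output_tokens = 150
)
# span.data now maps:
#   GEN_AI_USAGE_INPUT_TOKENS            -> 100
#   GEN_AI_USAGE_INPUT_TOKENS_CACHED     -> 20
#   GEN_AI_USAGE_OUTPUT_TOKENS           -> 50
#   GEN_AI_USAGE_OUTPUT_TOKENS_REASONING -> 10
#   GEN_AI_USAGE_TOTAL_TOKENS            -> 150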

sentry_sdk/integrations/anthropic.py

Lines changed: 13 additions & 2 deletions
@@ -65,7 +65,13 @@ def _calculate_token_usage(result, span):
         output_tokens = usage.output_tokens
 
     total_tokens = input_tokens + output_tokens
-    record_token_usage(span, input_tokens, output_tokens, total_tokens)
+
+    record_token_usage(
+        span,
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        total_tokens=total_tokens,
+    )
 
 
 def _get_responses(content):
@@ -126,7 +132,12 @@ def _add_ai_data_to_span(
                 [{"type": "text", "text": complete_message}],
             )
         total_tokens = input_tokens + output_tokens
-        record_token_usage(span, input_tokens, output_tokens, total_tokens)
+        record_token_usage(
+            span,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            total_tokens=total_tokens,
+        )
         span.set_data(SPANDATA.AI_STREAMING, True)
 
 
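Aside (an inference from the signature change, not stated in the commit): the keyword-argument migration above is not cosmetic. Because input_tokens_cached and output_tokens_reasoning were inserted between the existing parameters, an un-migrated positional call would silently bind to the wrong fields:

# OLD call shape against the NEW record_token_usage signature:
record_token_usage(span, input_tokens, output_tokens, total_tokens)
# binds: input_tokens  -> input_tokens        (ok)
#        output_tokens -> input_tokens_cached (wrong)
#        total_tokens  -> output_tokens       (wrong)
# hence every call site in this commit now passes explicit keywords.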

sentry_sdk/integrations/cohere.py

Lines changed: 5 additions & 5 deletions
@@ -116,14 +116,14 @@ def collect_chat_response_fields(span, res, include_pii):
     if hasattr(res.meta, "billed_units"):
         record_token_usage(
             span,
-            prompt_tokens=res.meta.billed_units.input_tokens,
-            completion_tokens=res.meta.billed_units.output_tokens,
+            input_tokens=res.meta.billed_units.input_tokens,
+            output_tokens=res.meta.billed_units.output_tokens,
         )
     elif hasattr(res.meta, "tokens"):
         record_token_usage(
             span,
-            prompt_tokens=res.meta.tokens.input_tokens,
-            completion_tokens=res.meta.tokens.output_tokens,
+            input_tokens=res.meta.tokens.input_tokens,
+            output_tokens=res.meta.tokens.output_tokens,
         )
 
     if hasattr(res.meta, "warnings"):
@@ -262,7 +262,7 @@ def new_embed(*args, **kwargs):
             ):
                 record_token_usage(
                     span,
-                    prompt_tokens=res.meta.billed_units.input_tokens,
+                    input_tokens=res.meta.billed_units.input_tokens,
                     total_tokens=res.meta.billed_units.input_tokens,
                 )
             return res

sentry_sdk/integrations/huggingface_hub.py

Lines changed: 8 additions & 2 deletions
@@ -111,7 +111,10 @@ def new_text_generation(*args, **kwargs):
                     [res.generated_text],
                 )
             if res.details is not None and res.details.generated_tokens > 0:
-                record_token_usage(span, total_tokens=res.details.generated_tokens)
+                record_token_usage(
+                    span,
+                    total_tokens=res.details.generated_tokens,
+                )
             span.__exit__(None, None, None)
             return res
 
@@ -145,7 +148,10 @@ def new_details_iterator():
                         span, SPANDATA.AI_RESPONSES, "".join(data_buf)
                     )
                 if tokens_used > 0:
-                    record_token_usage(span, total_tokens=tokens_used)
+                    record_token_usage(
+                        span,
+                        total_tokens=tokens_used,
+                    )
                 span.__exit__(None, None, None)
 
             return new_details_iterator()

sentry_sdk/integrations/langchain.py

Lines changed: 5 additions & 5 deletions
@@ -279,15 +279,15 @@ def on_llm_end(self, response, *, run_id, **kwargs):
             if token_usage:
                 record_token_usage(
                     span_data.span,
-                    token_usage.get("prompt_tokens"),
-                    token_usage.get("completion_tokens"),
-                    token_usage.get("total_tokens"),
+                    input_tokens=token_usage.get("prompt_tokens"),
+                    output_tokens=token_usage.get("completion_tokens"),
+                    total_tokens=token_usage.get("total_tokens"),
                 )
             else:
                 record_token_usage(
                     span_data.span,
-                    span_data.num_prompt_tokens,
-                    span_data.num_completion_tokens,
+                    input_tokens=span_data.num_prompt_tokens,
+                    output_tokens=span_data.num_completion_tokens,
                 )
 
             self._exit_span(span_data, run_id)

sentry_sdk/integrations/openai.py

Lines changed: 56 additions & 31 deletions
@@ -73,34 +73,49 @@ def _capture_exception(exc):
     sentry_sdk.capture_event(event, hint=hint)
 
 
-def _calculate_chat_completion_usage(
+def _get_usage(usage, names):
+    # type: (Any, List[str]) -> int
+    for name in names:
+        if hasattr(usage, name) and isinstance(getattr(usage, name), int):
+            return getattr(usage, name)
+    return 0
+
+
+def _calculate_token_usage(
     messages, response, span, streaming_message_responses, count_tokens
 ):
     # type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]], Callable[..., Any]) -> None
     input_tokens = 0  # type: Optional[int]
+    input_tokens_cached = 0  # type: Optional[int]
     output_tokens = 0  # type: Optional[int]
+    output_tokens_reasoning = 0  # type: Optional[int]
     total_tokens = 0  # type: Optional[int]
+
     if hasattr(response, "usage"):
-        if hasattr(response.usage, "input_tokens") and isinstance(
-            response.usage.input_tokens, int
-        ):
-            input_tokens = response.usage.input_tokens
+        input_tokens = _get_usage(response.usage, ["input_tokens", "prompt_tokens"])
+        if hasattr(response.usage, "input_tokens_details"):
+            input_tokens_cached = _get_usage(
+                response.usage.input_tokens_details, ["cached_tokens"]
+            )
 
-        if hasattr(response.usage, "output_tokens") and isinstance(
-            response.usage.output_tokens, int
-        ):
-            output_tokens = response.usage.output_tokens
+        output_tokens = _get_usage(
+            response.usage, ["output_tokens", "completion_tokens"]
+        )
+        if hasattr(response.usage, "output_tokens_details"):
+            output_tokens_reasoning = _get_usage(
+                response.usage.output_tokens_details, ["reasoning_tokens"]
+            )
 
-        if hasattr(response.usage, "total_tokens") and isinstance(
-            response.usage.total_tokens, int
-        ):
-            total_tokens = response.usage.total_tokens
+        total_tokens = _get_usage(response.usage, ["total_tokens"])
 
+    # Manually count tokens
+    # TODO: when implementing responses API, check for responses API
     if input_tokens == 0:
         for message in messages:
             if "content" in message:
                 input_tokens += count_tokens(message["content"])
 
+    # TODO: when implementing responses API, check for responses API
     if output_tokens == 0:
         if streaming_message_responses is not None:
             for message in streaming_message_responses:
@@ -110,13 +125,21 @@ def _calculate_chat_completion_usage(
                 if hasattr(choice, "message"):
                     output_tokens += count_tokens(choice.message)
 
-    if input_tokens == 0:
-        input_tokens = None
-    if output_tokens == 0:
-        output_tokens = None
-    if total_tokens == 0:
-        total_tokens = None
-    record_token_usage(span, input_tokens, output_tokens, total_tokens)
+    # Do not set token data if it is 0
+    input_tokens = input_tokens or None
+    input_tokens_cached = input_tokens_cached or None
+    output_tokens = output_tokens or None
+    output_tokens_reasoning = output_tokens_reasoning or None
+    total_tokens = total_tokens or None
+
+    record_token_usage(
+        span,
+        input_tokens=input_tokens,
+        input_tokens_cached=input_tokens_cached,
+        output_tokens=output_tokens,
+        output_tokens_reasoning=output_tokens_reasoning,
+        total_tokens=total_tokens,
+    )
 
 
 def _new_chat_completion_common(f, *args, **kwargs):
@@ -163,9 +186,7 @@ def _new_chat_completion_common(f, *args, **kwargs):
                     SPANDATA.AI_RESPONSES,
                     list(map(lambda x: x.message, res.choices)),
                 )
-            _calculate_chat_completion_usage(
-                messages, res, span, None, integration.count_tokens
-            )
+            _calculate_token_usage(messages, res, span, None, integration.count_tokens)
             span.__exit__(None, None, None)
         elif hasattr(res, "_iterator"):
             data_buf: list[list[str]] = []  # one for each choice
@@ -196,7 +217,7 @@ def new_iterator():
                         set_data_normalized(
                             span, SPANDATA.AI_RESPONSES, all_responses
                         )
-                    _calculate_chat_completion_usage(
+                    _calculate_token_usage(
                         messages,
                         res,
                         span,
@@ -229,7 +250,7 @@ async def new_iterator_async():
                        set_data_normalized(
                            span, SPANDATA.AI_RESPONSES, all_responses
                        )
-                    _calculate_chat_completion_usage(
+                    _calculate_token_usage(
                        messages,
                        res,
                        span,
@@ -346,22 +367,26 @@ def _new_embeddings_create_common(f, *args, **kwargs):
 
         response = yield f, args, kwargs
 
-        prompt_tokens = 0
+        input_tokens = 0
         total_tokens = 0
         if hasattr(response, "usage"):
             if hasattr(response.usage, "prompt_tokens") and isinstance(
                 response.usage.prompt_tokens, int
             ):
-                prompt_tokens = response.usage.prompt_tokens
+                input_tokens = response.usage.prompt_tokens
             if hasattr(response.usage, "total_tokens") and isinstance(
                 response.usage.total_tokens, int
             ):
                 total_tokens = response.usage.total_tokens
 
-        if prompt_tokens == 0:
-            prompt_tokens = integration.count_tokens(kwargs["input"] or "")
+        if input_tokens == 0:
+            input_tokens = integration.count_tokens(kwargs["input"] or "")
 
-        record_token_usage(span, prompt_tokens, None, total_tokens or prompt_tokens)
+        record_token_usage(
+            span,
+            input_tokens=input_tokens,
+            total_tokens=total_tokens or input_tokens,
+        )
 
         return response
 
@@ -464,7 +489,7 @@ def _new_responses_create_common(f, *args, **kwargs):
                 SPANDATA.GEN_AI_RESPONSE_TEXT,
                 json.dumps([item.to_dict() for item in res.output]),
             )
-        _calculate_chat_completion_usage([], res, span, None, integration.count_tokens)
+        _calculate_token_usage([], res, span, None, integration.count_tokens)
         span.__exit__(None, None, None)
 
     else:
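The new _get_usage helper is what lets _calculate_token_usage serve both the Chat Completions usage shape (prompt_tokens/completion_tokens) and the Responses API shape (input_tokens/output_tokens): it returns the first candidate attribute that exists and is an int, else 0. A quick sketch of that behaviour, with SimpleNamespace standing in for the real OpenAI usage objects:

from types import SimpleNamespace

from sentry_sdk.integrations.openai import _get_usage

# Stand-ins for the two usage shapes this code path has to accept:
chat_usage = SimpleNamespace(prompt_tokens=12, completion_tokens=34)
responses_usage = SimpleNamespace(input_tokens=56, output_tokens=78)

assert _get_usage(chat_usage, ["input_tokens", "prompt_tokens"]) == 12       # falls back
assert _get_usage(responses_usage, ["input_tokens", "prompt_tokens"]) == 56  # first match wins
assert _get_usage(chat_usage, ["total_tokens"]) == 0                         # absent -> 0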
