
Commit 2acd8fe

add token extraction for streams & openai
1 parent 01a30c1 commit 2acd8fe

File tree

1 file changed: +98 −67 lines changed


sentry_sdk/integrations/langchain.py

Lines changed: 98 additions & 67 deletions
@@ -48,7 +48,7 @@
 # token counts for models for which we have an explicit integration
 NO_COLLECT_TOKEN_MODELS = [
     # "openai-chat",
-    "anthropic-chat",
+    # "anthropic-chat",
     "cohere-chat",
     "huggingface_endpoint",
 ]
@@ -61,13 +61,10 @@ class LangchainIntegration(Integration):
     # The most number of spans (e.g., LLM calls) that can be processed at the same time.
     max_spans = 1024
 
-    def __init__(
-        self, include_prompts=True, max_spans=1024, tiktoken_encoding_name="cl100k_base"
-    ):
-        # type: (LangchainIntegration, bool, int, Optional[str]) -> None
+    def __init__(self, include_prompts=True, max_spans=1024):
+        # type: (LangchainIntegration, bool, int) -> None
         self.include_prompts = include_prompts
         self.max_spans = max_spans
-        self.tiktoken_encoding_name = tiktoken_encoding_name
 
     @staticmethod
     def setup_once():
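
Note: with the tiktoken_encoding_name option removed from LangchainIntegration, enabling the integration only involves the two remaining constructor arguments. A minimal setup sketch under the new signature (the DSN and sample rate below are placeholders, not part of this commit):

import sentry_sdk
from sentry_sdk.integrations.langchain import LangchainIntegration

sentry_sdk.init(
    dsn="https://examplePublicKey@o0.ingest.sentry.io/0",  # placeholder DSN
    traces_sample_rate=1.0,
    # the constructor no longer accepts tiktoken_encoding_name
    integrations=[LangchainIntegration(include_prompts=True, max_spans=1024)],
)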
@@ -77,8 +74,6 @@ def setup_once():
 
 class WatchedSpan:
     span = None  # type: Span
-    num_completion_tokens = 0  # type: int
-    num_prompt_tokens = 0  # type: int
     no_collect_tokens = False  # type: bool
     children = []  # type: List[WatchedSpan]
     is_pipeline = False  # type: bool
@@ -91,24 +86,12 @@ def __init__(self, span):
 class SentryLangchainCallback(BaseCallbackHandler):  # type: ignore[misc]
     """Base callback handler that can be used to handle callbacks from langchain."""
 
-    def __init__(self, max_span_map_size, include_prompts, tiktoken_encoding_name=None):
-        # type: (int, bool, Optional[str]) -> None
+    def __init__(self, max_span_map_size, include_prompts):
+        # type: (int, bool) -> None
         self.span_map = OrderedDict()  # type: OrderedDict[UUID, WatchedSpan]
         self.max_span_map_size = max_span_map_size
         self.include_prompts = include_prompts
 
-        self.tiktoken_encoding = None
-        if tiktoken_encoding_name is not None:
-            import tiktoken  # type: ignore
-
-            self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name)
-
-    def count_tokens(self, s):
-        # type: (str) -> int
-        if self.tiktoken_encoding is not None:
-            return len(self.tiktoken_encoding.encode_ordinary(s))
-        return 0
-
     def gc_span_map(self):
         # type: () -> None
@@ -163,9 +146,70 @@ def _extract_token_usage(self, token_usage):
 
         # LangChain's OpenAI callback uses these specific field names
         if input_tokens is None and hasattr(token_usage, "get"):
-            input_tokens = token_usage.get("prompt_tokens")
+            input_tokens = token_usage.get("prompt_tokens") or token_usage.get(
+                "input_tokens"
+            )
         if output_tokens is None and hasattr(token_usage, "get"):
-            output_tokens = token_usage.get("completion_tokens")
+            output_tokens = token_usage.get("completion_tokens") or token_usage.get(
+                "output_tokens"
+            )
+        if total_tokens is None and hasattr(token_usage, "get"):
+            total_tokens = token_usage.get("total_tokens")
+
+        return input_tokens, output_tokens, total_tokens
+
+    def _extract_token_usage_from_generations(self, generations):
+        # type: (Any) -> tuple[Optional[int], Optional[int], Optional[int]]
+        """Extract token usage from response.generations structure."""
+        if not generations:
+            return None, None, None
+
+        total_input = 0
+        total_output = 0
+        total_total = 0
+        found = False
+
+        for gen_list in generations:
+            for gen in gen_list:
+                usage_metadata = None
+                if (
+                    hasattr(gen, "message")
+                    and getattr(gen, "message", None) is not None
+                    and hasattr(gen.message, "usage_metadata")
+                ):
+                    usage_metadata = getattr(gen.message, "usage_metadata", None)
+                if usage_metadata is None and hasattr(gen, "usage_metadata"):
+                    usage_metadata = getattr(gen, "usage_metadata", None)
+                if usage_metadata:
+                    input_tokens, output_tokens, total_tokens = (
+                        self._extract_token_usage_from_response(usage_metadata)
+                    )
+                    if any([input_tokens, output_tokens, total_tokens]):
+                        found = True
+                        total_input += int(input_tokens)
+                        total_output += int(output_tokens)
+                        total_total += int(total_tokens)
+
+        if not found:
+            return None, None, None
+
+        return (
+            total_input if total_input > 0 else None,
+            total_output if total_output > 0 else None,
+            total_total if total_total > 0 else None,
+        )
+
+    def _extract_token_usage_from_response(self, response):
+        # type: (Any) -> tuple[int, int, int]
+        if response:
+            if hasattr(response, "get"):
+                input_tokens = response.get("input_tokens", 0)
+                output_tokens = response.get("output_tokens", 0)
+                total_tokens = response.get("total_tokens", 0)
+            else:
+                input_tokens = getattr(response, "input_tokens", 0)
+                output_tokens = getattr(response, "output_tokens", 0)
+                total_tokens = getattr(response, "total_tokens", 0)
 
         return input_tokens, output_tokens, total_tokens
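
The new helpers expect LangChain's usage_metadata shape (input_tokens / output_tokens / total_tokens) attached to each generation's message, which streamed chat responses typically carry instead of llm_output["token_usage"]. Below is a standalone, illustrative sketch of that shape and of the summation the helpers perform; the response object and token counts are made up for the example, and the loop is a simplified restatement rather than the SDK code itself:

# Illustrative only: mimics the structure _extract_token_usage_from_generations
# walks (response.generations -> gen.message.usage_metadata) and restates its
# per-generation summation in simplified form.
from types import SimpleNamespace

response = SimpleNamespace(
    llm_output=None,  # streaming responses often leave llm_output empty
    generations=[
        [
            SimpleNamespace(
                message=SimpleNamespace(
                    usage_metadata={
                        "input_tokens": 11,
                        "output_tokens": 7,
                        "total_tokens": 18,
                    }
                )
            )
        ]
    ],
)

total_input = total_output = total_total = 0
for gen_list in response.generations:
    for gen in gen_list:
        usage = getattr(getattr(gen, "message", None), "usage_metadata", None) or {}
        total_input += usage.get("input_tokens", 0)
        total_output += usage.get("output_tokens", 0)
        total_total += usage.get("total_tokens", 0)

print(total_input, total_output, total_total)  # 11 7 18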
@@ -278,12 +322,7 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
             for k, v in DATA_FIELDS.items():
                 if k in all_params:
                     set_data_normalized(span, v, all_params[k])
-            if not watched_span.no_collect_tokens:
-                for list_ in messages:
-                    for message in list_:
-                        self.span_map[run_id].num_prompt_tokens += self.count_tokens(
-                            message.content
-                        ) + self.count_tokens(message.type)
+            # no manual token counting
 
     def on_chat_model_end(self, response, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
@@ -294,6 +333,7 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
 
             token_usage = None
 
+            # Try multiple paths to extract token usage, prioritizing streaming-aware approaches
             if response.llm_output and "token_usage" in response.llm_output:
                 token_usage = response.llm_output["token_usage"]
             elif response.llm_output and hasattr(response.llm_output, "token_usage"):
@@ -302,6 +342,13 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
                 token_usage = response.usage
             elif hasattr(response, "token_usage"):
                 token_usage = response.token_usage
+            # Check for usage_metadata in llm_output (common in streaming responses)
+            elif response.llm_output and "usage_metadata" in response.llm_output:
+                token_usage = response.llm_output["usage_metadata"]
+            elif response.llm_output and hasattr(response.llm_output, "usage_metadata"):
+                token_usage = response.llm_output.usage_metadata
+            elif hasattr(response, "usage_metadata"):
+                token_usage = response.usage_metadata
 
             span_data = self.span_map[run_id]
             if not span_data:
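
The chain above keeps explicit token_usage fields as the highest priority and only falls back to usage_metadata when they are absent. A simplified, standalone restatement of that priority order (resolve_token_usage is a hypothetical helper for illustration, not the SDK function):

def resolve_token_usage(response):
    # Mirrors the if/elif chain in on_chat_model_end: explicit token_usage first,
    # then usage/token_usage attributes, then usage_metadata fallbacks.
    llm_output = getattr(response, "llm_output", None)
    if llm_output and "token_usage" in llm_output:
        return llm_output["token_usage"]
    if llm_output and hasattr(llm_output, "token_usage"):
        return llm_output.token_usage
    if hasattr(response, "usage"):
        return response.usage
    if hasattr(response, "token_usage"):
        return response.token_usage
    if llm_output and "usage_metadata" in llm_output:
        return llm_output["usage_metadata"]
    if llm_output and hasattr(llm_output, "usage_metadata"):
        return llm_output.usage_metadata
    if hasattr(response, "usage_metadata"):
        return response.usage_metadata
    return None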
@@ -319,42 +366,31 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
                 input_tokens, output_tokens, total_tokens = (
                     self._extract_token_usage(token_usage)
                 )
+            else:
+                input_tokens, output_tokens, total_tokens = (
+                    self._extract_token_usage_from_generations(response.generations)
+                )
 
+            if (
+                input_tokens is not None
+                or output_tokens is not None
+                or total_tokens is not None
+            ):
                 record_token_usage(
                     span_data.span,
                     input_tokens=input_tokens,
                     output_tokens=output_tokens,
                     total_tokens=total_tokens,
                 )
-            else:
-                record_token_usage(
-                    span_data.span,
-                    input_tokens=(
-                        span_data.num_prompt_tokens
-                        if span_data.num_prompt_tokens > 0
-                        else None
-                    ),
-                    output_tokens=(
-                        span_data.num_completion_tokens
-                        if span_data.num_completion_tokens > 0
-                        else None
-                    ),
-                )
 
             self._exit_span(span_data, run_id)
 
     def on_llm_new_token(self, token, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, str, UUID, Any) -> Any
         """Run on new LLM token. Only available when streaming is enabled."""
+        # no manual token counting
         with capture_internal_exceptions():
-            if not run_id or run_id not in self.span_map:
-                return
-            span_data = self.span_map[run_id]
-            if not span_data or span_data.no_collect_tokens:
-                return
-            # Count tokens for each streaming chunk
-            token_count = self.count_tokens(token)
-            span_data.num_completion_tokens += token_count
+            return
 
     def on_llm_end(self, response, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
@@ -392,26 +428,22 @@ def on_llm_end(self, response, *, run_id, **kwargs):
                 input_tokens, output_tokens, total_tokens = (
                     self._extract_token_usage(token_usage)
                 )
+            else:
+                input_tokens, output_tokens, total_tokens = (
+                    self._extract_token_usage_from_generations(response.generations)
+                )
+
+            if (
+                input_tokens is not None
+                or output_tokens is not None
+                or total_tokens is not None
+            ):
                 record_token_usage(
                     span_data.span,
                     input_tokens=input_tokens,
                     output_tokens=output_tokens,
                     total_tokens=total_tokens,
                 )
-            else:
-                record_token_usage(
-                    span_data.span,
-                    input_tokens=(
-                        span_data.num_prompt_tokens
-                        if span_data.num_prompt_tokens > 0
-                        else None
-                    ),
-                    output_tokens=(
-                        span_data.num_completion_tokens
-                        if span_data.num_completion_tokens > 0
-                        else None
-                    ),
-                )
 
             self._exit_span(span_data, run_id)
@@ -602,7 +634,6 @@ def new_configure(
         sentry_handler = SentryLangchainCallback(
             integration.max_spans,
             integration.include_prompts,
-            integration.tiktoken_encoding_name,
         )
         if isinstance(local_callbacks, BaseCallbackManager):
             local_callbacks = local_callbacks.copy()
