Commit 1b6ed45

committed
add data for tokens from langchain
1 parent 9669277 commit 1b6ed45

File tree

1 file changed: +213 -11 lines changed


sentry_sdk/integrations/langchain.py

Lines changed: 213 additions & 11 deletions
@@ -47,7 +47,7 @@
 # To avoid double collecting tokens, we do *not* measure
 # token counts for models for which we have an explicit integration
 NO_COLLECT_TOKEN_MODELS = [
-    "openai-chat",
+    # "openai-chat",
     "anthropic-chat",
     "cohere-chat",
     "huggingface_endpoint",
@@ -62,7 +62,7 @@ class LangchainIntegration(Integration):
     max_spans = 1024

     def __init__(
-        self, include_prompts=True, max_spans=1024, tiktoken_encoding_name=None
+        self, include_prompts=True, max_spans=1024, tiktoken_encoding_name="cl100k_base"
     ):
         # type: (LangchainIntegration, bool, int, Optional[str]) -> None
         self.include_prompts = include_prompts
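
For context, a minimal setup sketch (not part of this commit) showing how the integration is enabled. With the change above, the tiktoken-based fallback counting works out of the box, since tiktoken_encoding_name now defaults to "cl100k_base". The DSN and option values below are placeholders.

import sentry_sdk
from sentry_sdk.integrations.langchain import LangchainIntegration

sentry_sdk.init(
    dsn="https://examplePublicKey@o0.ingest.sentry.io/0",  # placeholder DSN
    traces_sample_rate=1.0,
    send_default_pii=True,  # required for include_prompts to attach prompt/response text
    integrations=[
        # tiktoken_encoding_name no longer needs to be passed explicitly;
        # it now defaults to "cl100k_base".
        LangchainIntegration(include_prompts=True),
    ],
)
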
@@ -134,6 +134,47 @@ def _normalize_langchain_message(self, message):
         parsed.update(message.additional_kwargs)
         return parsed

+    def _extract_token_usage(self, token_usage):
+        # type: (Any) -> tuple[Optional[int], Optional[int], Optional[int]]
+        """Extract input, output, and total tokens from various token usage formats.
+
+        Based on LangChain's callback pattern for token tracking:
+        https://python.langchain.com/docs/how_to/llm_token_usage_tracking/
+        """
+        if not token_usage:
+            return None, None, None
+
+        input_tokens = None
+        output_tokens = None
+        total_tokens = None
+
+        if hasattr(token_usage, "get"):
+            # Dictionary format - common in LangChain callbacks
+            input_tokens = token_usage.get("prompt_tokens") or token_usage.get(
+                "input_tokens"
+            )
+            output_tokens = token_usage.get("completion_tokens") or token_usage.get(
+                "output_tokens"
+            )
+            total_tokens = token_usage.get("total_tokens")
+        else:
+            # Object format - used by some model providers
+            input_tokens = getattr(token_usage, "prompt_tokens", None) or getattr(
+                token_usage, "input_tokens", None
+            )
+            output_tokens = getattr(token_usage, "completion_tokens", None) or getattr(
+                token_usage, "output_tokens", None
+            )
+            total_tokens = getattr(token_usage, "total_tokens", None)
+
+        # LangChain's OpenAI callback uses these specific field names
+        if input_tokens is None and hasattr(token_usage, "get"):
+            input_tokens = token_usage.get("prompt_tokens")
+        if output_tokens is None and hasattr(token_usage, "get"):
+            output_tokens = token_usage.get("completion_tokens")
+
+        return input_tokens, output_tokens, total_tokens
+
     def _create_span(self, run_id, parent_id, **kwargs):
         # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan
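
As an illustration (not part of the diff), the new helper is meant to normalize both shapes seen in practice: the dict that LangChain's OpenAI-style callbacks put under llm_output["token_usage"], and attribute-style usage objects from some providers. The sample values below are made up.

# Hypothetical inputs showing the two shapes _extract_token_usage() accepts.

# 1) Dict format (e.g. llm_output["token_usage"] from an OpenAI-style chat model):
usage_dict = {"prompt_tokens": 12, "completion_tokens": 7, "total_tokens": 19}

# 2) Object format (a provider-specific usage object exposing attributes):
class FakeUsage:
    input_tokens = 12
    output_tokens = 7
    total_tokens = 19

# Both are expected to normalize to the same tuple:
# callback._extract_token_usage(usage_dict)  -> (12, 7, 19)
# callback._extract_token_usage(FakeUsage()) -> (12, 7, 19)
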

@@ -250,16 +291,119 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
                             message.content
                         ) + self.count_tokens(message.type)

+    def on_chat_model_end(self, response, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
+        """Run when Chat Model ends running."""
+        with capture_internal_exceptions():
+            if not run_id:
+                return
+
+            # Extract token usage following LangChain's callback pattern
+            # Reference: https://python.langchain.com/docs/how_to/llm_token_usage_tracking/
+            token_usage = None
+
+            # Debug: Log the response structure to understand what's available
+            logger.debug(
+                "LangChain response structure: llm_output=%s, has_usage=%s",
+                bool(response.llm_output),
+                hasattr(response, "usage"),
+            )
+
+            if response.llm_output and "token_usage" in response.llm_output:
+                token_usage = response.llm_output["token_usage"]
+                logger.debug("Found token_usage in llm_output dict: %s", token_usage)
+            elif response.llm_output and hasattr(response.llm_output, "token_usage"):
+                token_usage = response.llm_output.token_usage
+                logger.debug(
+                    "Found token_usage as llm_output attribute: %s", token_usage
+                )
+            elif hasattr(response, "usage"):
+                # Some models might have usage directly on the response (OpenAI-style)
+                token_usage = response.usage
+                logger.debug("Found usage on response: %s", token_usage)
+            elif hasattr(response, "token_usage"):
+                # Direct token_usage attribute
+                token_usage = response.token_usage
+                logger.debug("Found token_usage on response: %s", token_usage)
+            else:
+                logger.debug(
+                    "No token usage found in response, will use manual counting"
+                )
+
+            span_data = self.span_map[run_id]
+            if not span_data:
+                return
+
+            if should_send_default_pii() and self.include_prompts:
+                set_data_normalized(
+                    span_data.span,
+                    SPANDATA.GEN_AI_RESPONSE_TEXT,
+                    [[x.text for x in list_] for list_ in response.generations],
+                )
+
+            if not span_data.no_collect_tokens:
+                if token_usage:
+                    input_tokens, output_tokens, total_tokens = (
+                        self._extract_token_usage(token_usage)
+                    )
+                    # Log token usage for debugging (will be removed in production)
+                    logger.debug(
+                        "LangChain token usage found: input=%s, output=%s, total=%s",
+                        input_tokens,
+                        output_tokens,
+                        total_tokens,
+                    )
+                    record_token_usage(
+                        span_data.span,
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        total_tokens=total_tokens,
+                    )
+                else:
+                    # Fallback to manual token counting when no usage info is available
+                    logger.debug(
+                        "No token usage from LangChain, using manual count: input=%s, output=%s",
+                        span_data.num_prompt_tokens,
+                        span_data.num_completion_tokens,
+                    )
+                    record_token_usage(
+                        span_data.span,
+                        input_tokens=(
+                            span_data.num_prompt_tokens
+                            if span_data.num_prompt_tokens > 0
+                            else None
+                        ),
+                        output_tokens=(
+                            span_data.num_completion_tokens
+                            if span_data.num_completion_tokens > 0
+                            else None
+                        ),
+                    )
+
+            self._exit_span(span_data, run_id)
+
     def on_llm_new_token(self, token, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, str, UUID, Any) -> Any
-        """Run on new LLM token. Only available when streaming is enabled."""
+        """Run on new LLM token. Only available when streaming is enabled.
+
+        Note: LangChain documentation mentions that streaming token counts
+        may not be fully supported for all models. This provides a fallback
+        for manual counting during streaming.
+        """
         with capture_internal_exceptions():
             if not run_id or run_id not in self.span_map:
                 return
             span_data = self.span_map[run_id]
             if not span_data or span_data.no_collect_tokens:
                 return
-            span_data.num_completion_tokens += self.count_tokens(token)
+            # Count tokens for each streaming chunk
+            token_count = self.count_tokens(token)
+            span_data.num_completion_tokens += token_count
+            logger.debug(
+                "Streaming token count updated: +%s (total: %s)",
+                token_count,
+                span_data.num_completion_tokens,
+            )

     def on_llm_end(self, response, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
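
For reference, a rough sketch of the response shape the branches above probe for. Exact fields vary by provider and LangChain version, so treat the values and the model_name key as illustrative only.

# Illustrative llm_output payload on an LLMResult from an OpenAI-style chat model:
llm_output = {
    "token_usage": {
        "prompt_tokens": 26,
        "completion_tokens": 8,
        "total_tokens": 34,
    },
    "model_name": "gpt-4o-mini",  # example value, not asserted by this commit
}
# With this shape, on_chat_model_end() takes the first branch
# ("token_usage" in response.llm_output) and hands the nested dict to
# _extract_token_usage(), which yields (26, 8, 34) for record_token_usage().
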
@@ -268,10 +412,38 @@ def on_llm_end(self, response, *, run_id, **kwargs):
             if not run_id:
                 return

-            token_usage = (
-                response.llm_output.get("token_usage") if response.llm_output else None
+            # Extract token usage following LangChain's callback pattern
+            # Reference: https://python.langchain.com/docs/how_to/llm_token_usage_tracking/
+            token_usage = None
+
+            # Debug: Log the response structure to understand what's available
+            logger.debug(
+                "LangChain response structure: llm_output=%s, has_usage=%s",
+                bool(response.llm_output),
+                hasattr(response, "usage"),
             )

+            if response.llm_output and "token_usage" in response.llm_output:
+                token_usage = response.llm_output["token_usage"]
+                logger.debug("Found token_usage in llm_output dict: %s", token_usage)
+            elif response.llm_output and hasattr(response.llm_output, "token_usage"):
+                token_usage = response.llm_output.token_usage
+                logger.debug(
+                    "Found token_usage as llm_output attribute: %s", token_usage
+                )
+            elif hasattr(response, "usage"):
+                # Some models might have usage directly on the response (OpenAI-style)
+                token_usage = response.usage
+                logger.debug("Found usage on response: %s", token_usage)
+            elif hasattr(response, "token_usage"):
+                # Direct token_usage attribute
+                token_usage = response.token_usage
+                logger.debug("Found token_usage on response: %s", token_usage)
+            else:
+                logger.debug(
+                    "No token usage found in response, will use manual counting"
+                )
+
             span_data = self.span_map[run_id]
             if not span_data:
                 return
@@ -285,17 +457,41 @@ def on_llm_end(self, response, *, run_id, **kwargs):

             if not span_data.no_collect_tokens:
                 if token_usage:
+                    input_tokens, output_tokens, total_tokens = (
+                        self._extract_token_usage(token_usage)
+                    )
+                    # Log token usage for debugging (will be removed in production)
+                    logger.debug(
+                        "LangChain token usage found: input=%s, output=%s, total=%s",
+                        input_tokens,
+                        output_tokens,
+                        total_tokens,
+                    )
                     record_token_usage(
                         span_data.span,
-                        input_tokens=token_usage.get("prompt_tokens"),
-                        output_tokens=token_usage.get("completion_tokens"),
-                        total_tokens=token_usage.get("total_tokens"),
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        total_tokens=total_tokens,
                     )
                 else:
+                    # Fallback to manual token counting when no usage info is available
+                    logger.debug(
+                        "No token usage from LangChain, using manual count: input=%s, output=%s",
+                        span_data.num_prompt_tokens,
+                        span_data.num_completion_tokens,
+                    )
                     record_token_usage(
                         span_data.span,
-                        input_tokens=span_data.num_prompt_tokens,
-                        output_tokens=span_data.num_completion_tokens,
+                        input_tokens=(
+                            span_data.num_prompt_tokens
+                            if span_data.num_prompt_tokens > 0
+                            else None
+                        ),
+                        output_tokens=(
+                            span_data.num_completion_tokens
+                            if span_data.num_completion_tokens > 0
+                            else None
+                        ),
                     )

             self._exit_span(span_data, run_id)
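
The fallback branch above relies on the callback's count_tokens() helper. Below is a minimal sketch of tiktoken-based counting with the new "cl100k_base" default, assuming tiktoken is installed; it mirrors the idea, not the SDK's exact implementation.

import tiktoken

def count_tokens_sketch(text, encoding_name="cl100k_base"):
    # Approximate token count used when the provider reports no usage data.
    encoding = tiktoken.get_encoding(encoding_name)
    return len(encoding.encode(text or ""))

# Example: accumulate completion tokens chunk by chunk while streaming,
# the way on_llm_new_token() does above.
num_completion_tokens = 0
for chunk in ["Hello", ", ", "world", "!"]:
    num_completion_tokens += count_tokens_sketch(chunk)
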
@@ -306,6 +502,12 @@ def on_llm_error(self, error, *, run_id, **kwargs):
         with capture_internal_exceptions():
             self._handle_error(run_id, error)

+    def on_chat_model_error(self, error, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
+        """Run when Chat Model errors."""
+        with capture_internal_exceptions():
+            self._handle_error(run_id, error)
+
     def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, Dict[str, Any], Dict[str, Any], UUID, Any) -> Any
         """Run when chain starts running."""
