 # token counts for models for which we have an explicit integration
 NO_COLLECT_TOKEN_MODELS = [
     # "openai-chat",
-    "anthropic-chat",
+    # "anthropic-chat",
     "cohere-chat",
     "huggingface_endpoint",
 ]
@@ -61,13 +61,10 @@ class LangchainIntegration(Integration):
     # The most number of spans (e.g., LLM calls) that can be processed at the same time.
     max_spans = 1024
 
-    def __init__(
-        self, include_prompts=True, max_spans=1024, tiktoken_encoding_name="cl100k_base"
-    ):
-        # type: (LangchainIntegration, bool, int, Optional[str]) -> None
+    def __init__(self, include_prompts=True, max_spans=1024):
+        # type: (LangchainIntegration, bool, int) -> None
         self.include_prompts = include_prompts
         self.max_spans = max_spans
-        self.tiktoken_encoding_name = tiktoken_encoding_name
 
     @staticmethod
     def setup_once():
@@ -77,8 +74,6 @@ def setup_once():
 
 class WatchedSpan:
     span = None  # type: Span
-    num_completion_tokens = 0  # type: int
-    num_prompt_tokens = 0  # type: int
     no_collect_tokens = False  # type: bool
     children = []  # type: List[WatchedSpan]
     is_pipeline = False  # type: bool
@@ -91,24 +86,12 @@ def __init__(self, span):
 class SentryLangchainCallback(BaseCallbackHandler):  # type: ignore[misc]
     """Base callback handler that can be used to handle callbacks from langchain."""
 
-    def __init__(self, max_span_map_size, include_prompts, tiktoken_encoding_name=None):
-        # type: (int, bool, Optional[str]) -> None
+    def __init__(self, max_span_map_size, include_prompts):
+        # type: (int, bool) -> None
         self.span_map = OrderedDict()  # type: OrderedDict[UUID, WatchedSpan]
         self.max_span_map_size = max_span_map_size
         self.include_prompts = include_prompts
 
-        self.tiktoken_encoding = None
-        if tiktoken_encoding_name is not None:
-            import tiktoken  # type: ignore
-
-            self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name)
-
-    def count_tokens(self, s):
-        # type: (str) -> int
-        if self.tiktoken_encoding is not None:
-            return len(self.tiktoken_encoding.encode_ordinary(s))
-        return 0
-
     def gc_span_map(self):
         # type: () -> None
 
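A minimal usage sketch under the new constructor signature — `tiktoken_encoding_name` is no longer accepted, and the DSN below is a placeholder:

```python
import sentry_sdk
from sentry_sdk.integrations.langchain import LangchainIntegration

# tiktoken_encoding_name has been removed from the constructor; token usage
# now comes from provider-reported metadata rather than local tiktoken counting.
sentry_sdk.init(
    dsn="https://examplePublicKey@o0.ingest.sentry.io/0",  # placeholder DSN
    traces_sample_rate=1.0,
    integrations=[
        LangchainIntegration(include_prompts=True, max_spans=1024),
    ],
)
```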
@@ -163,9 +146,70 @@ def _extract_token_usage(self, token_usage):
 
         # LangChain's OpenAI callback uses these specific field names
         if input_tokens is None and hasattr(token_usage, "get"):
-            input_tokens = token_usage.get("prompt_tokens")
+            input_tokens = token_usage.get("prompt_tokens") or token_usage.get(
+                "input_tokens"
+            )
         if output_tokens is None and hasattr(token_usage, "get"):
-            output_tokens = token_usage.get("completion_tokens")
+            output_tokens = token_usage.get("completion_tokens") or token_usage.get(
+                "output_tokens"
+            )
+        if total_tokens is None and hasattr(token_usage, "get"):
+            total_tokens = token_usage.get("total_tokens")
+
+        return input_tokens, output_tokens, total_tokens
+
+    def _extract_token_usage_from_generations(self, generations):
+        # type: (Any) -> tuple[Optional[int], Optional[int], Optional[int]]
+        """Extract token usage from response.generations structure."""
+        if not generations:
+            return None, None, None
+
+        total_input = 0
+        total_output = 0
+        total_total = 0
+        found = False
+
+        for gen_list in generations:
+            for gen in gen_list:
+                usage_metadata = None
+                if (
+                    hasattr(gen, "message")
+                    and getattr(gen, "message", None) is not None
+                    and hasattr(gen.message, "usage_metadata")
+                ):
+                    usage_metadata = getattr(gen.message, "usage_metadata", None)
+                if usage_metadata is None and hasattr(gen, "usage_metadata"):
+                    usage_metadata = getattr(gen, "usage_metadata", None)
+                if usage_metadata:
+                    input_tokens, output_tokens, total_tokens = (
+                        self._extract_token_usage_from_response(usage_metadata)
+                    )
+                    if any([input_tokens, output_tokens, total_tokens]):
+                        found = True
+                        total_input += int(input_tokens)
+                        total_output += int(output_tokens)
+                        total_total += int(total_tokens)
+
+        if not found:
+            return None, None, None
+
+        return (
+            total_input if total_input > 0 else None,
+            total_output if total_output > 0 else None,
+            total_total if total_total > 0 else None,
+        )
+
+    def _extract_token_usage_from_response(self, response):
+        # type: (Any) -> tuple[int, int, int]
+        if response:
+            if hasattr(response, "get"):
+                input_tokens = response.get("input_tokens", 0)
+                output_tokens = response.get("output_tokens", 0)
+                total_tokens = response.get("total_tokens", 0)
+            else:
+                input_tokens = getattr(response, "input_tokens", 0)
+                output_tokens = getattr(response, "output_tokens", 0)
+                total_tokens = getattr(response, "total_tokens", 0)
 
         return input_tokens, output_tokens, total_tokens
 
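For reference, a rough sketch of the structure `_extract_token_usage_from_generations` traverses, using `types.SimpleNamespace` as a stand-in for LangChain's `ChatGeneration`/`AIMessage` objects (not the real classes, values are illustrative):

```python
from types import SimpleNamespace

# Hypothetical stand-ins: a chat generation carrying a message whose
# usage_metadata maps input/output/total token counts.
message = SimpleNamespace(
    usage_metadata={"input_tokens": 12, "output_tokens": 30, "total_tokens": 42}
)
generation = SimpleNamespace(message=message)

# response.generations is a list of lists; the helper sums usage across all
# generations, so this shape would yield (12, 30, 42), and it returns
# (None, None, None) when no usage_metadata is found anywhere.
generations = [[generation]]
```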
@@ -278,12 +322,7 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
             for k, v in DATA_FIELDS.items():
                 if k in all_params:
                     set_data_normalized(span, v, all_params[k])
-            if not watched_span.no_collect_tokens:
-                for list_ in messages:
-                    for message in list_:
-                        self.span_map[run_id].num_prompt_tokens += self.count_tokens(
-                            message.content
-                        ) + self.count_tokens(message.type)
+            # no manual token counting
 
     def on_chat_model_end(self, response, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
@@ -294,6 +333,7 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
 
             token_usage = None
 
+            # Try multiple paths to extract token usage, prioritizing streaming-aware approaches
             if response.llm_output and "token_usage" in response.llm_output:
                 token_usage = response.llm_output["token_usage"]
             elif response.llm_output and hasattr(response.llm_output, "token_usage"):
@@ -302,6 +342,13 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
                 token_usage = response.usage
             elif hasattr(response, "token_usage"):
                 token_usage = response.token_usage
+            # Check for usage_metadata in llm_output (common in streaming responses)
+            elif response.llm_output and "usage_metadata" in response.llm_output:
+                token_usage = response.llm_output["usage_metadata"]
+            elif response.llm_output and hasattr(response.llm_output, "usage_metadata"):
+                token_usage = response.llm_output.usage_metadata
+            elif hasattr(response, "usage_metadata"):
+                token_usage = response.usage_metadata
 
             span_data = self.span_map[run_id]
             if not span_data:
@@ -319,42 +366,31 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
                 input_tokens, output_tokens, total_tokens = (
                     self._extract_token_usage(token_usage)
                 )
+            else:
+                input_tokens, output_tokens, total_tokens = (
+                    self._extract_token_usage_from_generations(response.generations)
+                )
 
+            if (
+                input_tokens is not None
+                or output_tokens is not None
+                or total_tokens is not None
+            ):
                 record_token_usage(
                     span_data.span,
                     input_tokens=input_tokens,
                     output_tokens=output_tokens,
                     total_tokens=total_tokens,
                 )
-            else:
-                record_token_usage(
-                    span_data.span,
-                    input_tokens=(
-                        span_data.num_prompt_tokens
-                        if span_data.num_prompt_tokens > 0
-                        else None
-                    ),
-                    output_tokens=(
-                        span_data.num_completion_tokens
-                        if span_data.num_completion_tokens > 0
-                        else None
-                    ),
-                )
 
             self._exit_span(span_data, run_id)
 
     def on_llm_new_token(self, token, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, str, UUID, Any) -> Any
         """Run on new LLM token. Only available when streaming is enabled."""
+        # no manual token counting
         with capture_internal_exceptions():
-            if not run_id or run_id not in self.span_map:
-                return
-            span_data = self.span_map[run_id]
-            if not span_data or span_data.no_collect_tokens:
-                return
-            # Count tokens for each streaming chunk
-            token_count = self.count_tokens(token)
-            span_data.num_completion_tokens += token_count
+            return
 
     def on_llm_end(self, response, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
@@ -392,26 +428,22 @@ def on_llm_end(self, response, *, run_id, **kwargs):
                 input_tokens, output_tokens, total_tokens = (
                     self._extract_token_usage(token_usage)
                 )
+            else:
+                input_tokens, output_tokens, total_tokens = (
+                    self._extract_token_usage_from_generations(response.generations)
+                )
+
+            if (
+                input_tokens is not None
+                or output_tokens is not None
+                or total_tokens is not None
+            ):
                 record_token_usage(
                     span_data.span,
                     input_tokens=input_tokens,
                     output_tokens=output_tokens,
                     total_tokens=total_tokens,
                 )
-            else:
-                record_token_usage(
-                    span_data.span,
-                    input_tokens=(
-                        span_data.num_prompt_tokens
-                        if span_data.num_prompt_tokens > 0
-                        else None
-                    ),
-                    output_tokens=(
-                        span_data.num_completion_tokens
-                        if span_data.num_completion_tokens > 0
-                        else None
-                    ),
-                )
 
             self._exit_span(span_data, run_id)
 
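In both `on_chat_model_end` and `on_llm_end` the effect of the change is the same: provider-reported usage wins, and there is no tiktoken-based fallback anymore. A standalone sketch of the new guard, with a hypothetical `maybe_record` helper (not in the diff) standing in for the `if ... is not None` block around `record_token_usage`:

```python
# Hypothetical illustration: token usage is recorded only when at least one
# count was actually reported; otherwise nothing is recorded at all.
def maybe_record(span, input_tokens, output_tokens, total_tokens, record):
    if (
        input_tokens is not None
        or output_tokens is not None
        or total_tokens is not None
    ):
        record(span, input_tokens, output_tokens, total_tokens)


maybe_record("span", None, None, None, print)  # records nothing
maybe_record("span", 12, 30, 42, print)        # records provider-reported counts
```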
@@ -602,7 +634,6 @@ def new_configure(
             sentry_handler = SentryLangchainCallback(
                 integration.max_spans,
                 integration.include_prompts,
-                integration.tiktoken_encoding_name,
             )
             if isinstance(local_callbacks, BaseCallbackManager):
                 local_callbacks = local_callbacks.copy()