@@ -73,34 +73,49 @@ def _capture_exception(exc):
7373 sentry_sdk .capture_event (event , hint = hint )
7474
7575
76- def _calculate_chat_completion_usage (
76+ def _get_usage (usage , names ):
77+ # type: (Any, List[str]) -> int
78+ for name in names :
79+ if hasattr (usage , name ) and isinstance (getattr (usage , name ), int ):
80+ return getattr (usage , name )
81+ return 0
82+
83+
84+ def _calculate_token_usage (
7785 messages , response , span , streaming_message_responses , count_tokens
7886):
7987 # type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]], Callable[..., Any]) -> None
8088 input_tokens = 0 # type: Optional[int]
89+ input_tokens_cached = 0 # type: Optional[int]
8190 output_tokens = 0 # type: Optional[int]
91+ output_tokens_reasoning = 0 # type: Optional[int]
8292 total_tokens = 0 # type: Optional[int]
93+
8394 if hasattr (response , "usage" ):
84- if hasattr (response .usage , "input_tokens" ) and isinstance (
85- response .usage .input_tokens , int
86- ):
87- input_tokens = response .usage .input_tokens
95+ input_tokens = _get_usage (response .usage , ["input_tokens" , "prompt_tokens" ])
96+ if hasattr (response .usage , "input_tokens_details" ):
97+ input_tokens_cached = _get_usage (
98+ response .usage .input_tokens_details , ["cached_tokens" ]
99+ )
88100
89- if hasattr (response .usage , "output_tokens" ) and isinstance (
90- response .usage .output_tokens , int
91- ):
92- output_tokens = response .usage .output_tokens
101+ output_tokens = _get_usage (
102+ response .usage , ["output_tokens" , "completion_tokens" ]
103+ )
104+ if hasattr (response .usage , "output_tokens_details" ):
105+ output_tokens_reasoning = _get_usage (
106+ response .usage .output_tokens_details , ["reasoning_tokens" ]
107+ )
93108
94- if hasattr (response .usage , "total_tokens" ) and isinstance (
95- response .usage .total_tokens , int
96- ):
97- total_tokens = response .usage .total_tokens
109+ total_tokens = _get_usage (response .usage , ["total_tokens" ])
98110
111+ # Manually count tokens
112+ # TODO: when implementing responses API, check for responses API
99113 if input_tokens == 0 :
100114 for message in messages :
101115 if "content" in message :
102116 input_tokens += count_tokens (message ["content" ])
103117
118+ # TODO: when implementing responses API, check for responses API
104119 if output_tokens == 0 :
105120 if streaming_message_responses is not None :
106121 for message in streaming_message_responses :
@@ -110,13 +125,21 @@ def _calculate_chat_completion_usage(
110125 if hasattr (choice , "message" ):
111126 output_tokens += count_tokens (choice .message )
112127
113- if input_tokens == 0 :
114- input_tokens = None
115- if output_tokens == 0 :
116- output_tokens = None
117- if total_tokens == 0 :
118- total_tokens = None
119- record_token_usage (span , input_tokens , output_tokens , total_tokens )
128+ # Do not set token data if it is 0
129+ input_tokens = input_tokens or None
130+ input_tokens_cached = input_tokens_cached or None
131+ output_tokens = output_tokens or None
132+ output_tokens_reasoning = output_tokens_reasoning or None
133+ total_tokens = total_tokens or None
134+
135+ record_token_usage (
136+ span ,
137+ input_tokens = input_tokens ,
138+ input_tokens_cached = input_tokens_cached ,
139+ output_tokens = output_tokens ,
140+ output_tokens_reasoning = output_tokens_reasoning ,
141+ total_tokens = total_tokens ,
142+ )
120143
121144
122145def _new_chat_completion_common (f , * args , ** kwargs ):
@@ -163,9 +186,7 @@ def _new_chat_completion_common(f, *args, **kwargs):
163186 SPANDATA .AI_RESPONSES ,
164187 list (map (lambda x : x .message , res .choices )),
165188 )
166- _calculate_chat_completion_usage (
167- messages , res , span , None , integration .count_tokens
168- )
189+ _calculate_token_usage (messages , res , span , None , integration .count_tokens )
169190 span .__exit__ (None , None , None )
170191 elif hasattr (res , "_iterator" ):
171192 data_buf : list [list [str ]] = [] # one for each choice
@@ -196,7 +217,7 @@ def new_iterator():
196217 set_data_normalized (
197218 span , SPANDATA .AI_RESPONSES , all_responses
198219 )
199- _calculate_chat_completion_usage (
220+ _calculate_token_usage (
200221 messages ,
201222 res ,
202223 span ,
@@ -229,7 +250,7 @@ async def new_iterator_async():
229250 set_data_normalized (
230251 span , SPANDATA .AI_RESPONSES , all_responses
231252 )
232- _calculate_chat_completion_usage (
253+ _calculate_token_usage (
233254 messages ,
234255 res ,
235256 span ,
@@ -346,22 +367,26 @@ def _new_embeddings_create_common(f, *args, **kwargs):
346367
347368 response = yield f , args , kwargs
348369
349- prompt_tokens = 0
370+ input_tokens = 0
350371 total_tokens = 0
351372 if hasattr (response , "usage" ):
352373 if hasattr (response .usage , "prompt_tokens" ) and isinstance (
353374 response .usage .prompt_tokens , int
354375 ):
355- prompt_tokens = response .usage .prompt_tokens
376+ input_tokens = response .usage .prompt_tokens
356377 if hasattr (response .usage , "total_tokens" ) and isinstance (
357378 response .usage .total_tokens , int
358379 ):
359380 total_tokens = response .usage .total_tokens
360381
361- if prompt_tokens == 0 :
362- prompt_tokens = integration .count_tokens (kwargs ["input" ] or "" )
382+ if input_tokens == 0 :
383+ input_tokens = integration .count_tokens (kwargs ["input" ] or "" )
363384
364- record_token_usage (span , prompt_tokens , None , total_tokens or prompt_tokens )
385+ record_token_usage (
386+ span ,
387+ input_tokens = input_tokens ,
388+ total_tokens = total_tokens or input_tokens ,
389+ )
365390
366391 return response
367392
@@ -464,7 +489,7 @@ def _new_responses_create_common(f, *args, **kwargs):
464489 SPANDATA .GEN_AI_RESPONSE_TEXT ,
465490 json .dumps ([item .to_dict () for item in res .output ]),
466491 )
467- _calculate_chat_completion_usage ([], res , span , None , integration .count_tokens )
492+ _calculate_token_usage ([], res , span , None , integration .count_tokens )
468493 span .__exit__ (None , None , None )
469494
470495 else :
0 commit comments