@@ -70,48 +70,75 @@ def _capture_exception(exc):
7070 sentry_sdk .capture_event (event , hint = hint )
7171
7272
73- def _calculate_chat_completion_usage (
73+ def _get_usage (usage , names ):
74+ # type: (Any, List[str]) -> int
75+ for name in names :
76+ if hasattr (usage , name ) and isinstance (getattr (usage , name ), int ):
77+ return getattr (usage , name )
78+ return 0
79+
80+
def _calculate_token_usage(
    messages, response, span, streaming_message_responses, count_tokens
):
    # type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]], Callable[..., Any]) -> None
    """Collect token counts for an OpenAI call and record them on ``span``.

    Counts are read from ``response.usage`` when present, accepting both the
    ``input_tokens``/``output_tokens`` and ``prompt_tokens``/``completion_tokens``
    attribute spellings, plus the cached/reasoning sub-counts from the
    ``*_tokens_details`` objects when available.  If the API reported no
    prompt or completion count, the tokens are counted manually with
    ``count_tokens`` over the request messages and the (possibly streamed)
    responses.  Any count that is still 0 is passed as ``None`` so it is not
    set on the span.
    """
    input_tokens = 0  # type: Optional[int]
    input_tokens_cached = 0  # type: Optional[int]
    output_tokens = 0  # type: Optional[int]
    output_tokens_reasoning = 0  # type: Optional[int]
    total_tokens = 0  # type: Optional[int]

    if hasattr(response, "usage"):
        usage = response.usage
        input_tokens = _get_usage(usage, ["input_tokens", "prompt_tokens"])
        output_tokens = _get_usage(usage, ["output_tokens", "completion_tokens"])
        total_tokens = _get_usage(usage, ["total_tokens"])

        if hasattr(usage, "input_tokens_details"):
            input_tokens_cached = _get_usage(
                usage.input_tokens_details, ["cached_tokens"]
            )
        if hasattr(usage, "output_tokens_details"):
            output_tokens_reasoning = _get_usage(
                usage.output_tokens_details, ["reasoning_tokens"]
            )

    # Fall back to counting the prompt tokens manually.
    # TODO: check for responses API
    if input_tokens == 0:
        input_tokens = sum(
            count_tokens(message["content"])
            for message in messages
            if "content" in message
        )

    # Fall back to counting the completion tokens manually, preferring the
    # accumulated streaming chunks over the non-streaming choices.
    # TODO: check for responses API
    if output_tokens == 0:
        if streaming_message_responses is not None:
            for streamed_message in streaming_message_responses:
                output_tokens += count_tokens(streamed_message)
        elif hasattr(response, "choices"):
            output_tokens = sum(
                count_tokens(choice.message)
                for choice in response.choices
                if hasattr(choice, "message")
            )

    # A count of 0 means "unknown" here; `x or None` maps 0 -> None so the
    # span never records a zero token count.
    record_token_usage(
        span,
        input_tokens=input_tokens or None,
        input_tokens_cached=input_tokens_cached or None,
        output_tokens=output_tokens or None,
        output_tokens_reasoning=output_tokens_reasoning or None,
        total_tokens=total_tokens or None,
    )
115142
116143
117144def _new_chat_completion_common (f , * args , ** kwargs ):
@@ -158,9 +185,7 @@ def _new_chat_completion_common(f, *args, **kwargs):
158185 SPANDATA .AI_RESPONSES ,
159186 list (map (lambda x : x .message , res .choices )),
160187 )
161- _calculate_chat_completion_usage (
162- messages , res , span , None , integration .count_tokens
163- )
188+ _calculate_token_usage (messages , res , span , None , integration .count_tokens )
164189 span .__exit__ (None , None , None )
165190 elif hasattr (res , "_iterator" ):
166191 data_buf : list [list [str ]] = [] # one for each choice
@@ -191,7 +216,7 @@ def new_iterator():
191216 set_data_normalized (
192217 span , SPANDATA .AI_RESPONSES , all_responses
193218 )
194- _calculate_chat_completion_usage (
219+ _calculate_token_usage (
195220 messages ,
196221 res ,
197222 span ,
@@ -224,7 +249,7 @@ async def new_iterator_async():
224249 set_data_normalized (
225250 span , SPANDATA .AI_RESPONSES , all_responses
226251 )
227- _calculate_chat_completion_usage (
252+ _calculate_token_usage (
228253 messages ,
229254 res ,
230255 span ,
0 commit comments