@@ -156,11 +156,18 @@ def _calculate_token_usage(
156156 )
157157
158158
159- def _set_request_data (span , kwargs , integration ):
160- # type: (Span, dict[str, Any], Integration) -> None
159+ def _set_request_data (span , kwargs , operation , integration ):
160+ # type: (Span, dict[str, Any], str, Integration) -> None
161161 messages = kwargs .get ("messages" )
162+ if messages is None :
163+ messages = kwargs .get ("input" )
164+
165+ if isinstance (messages , str ):
166+ messages = [messages ]
167+
162168 if (
163169 messages is not None
170+ and len (messages ) > 0
164171 and should_send_default_pii ()
165172 and integration .include_prompts
166173 ):
@@ -171,7 +178,7 @@ def _set_request_data(span, kwargs, integration):
171178 streaming = kwargs .get ("stream" )
172179 set_data_normalized (span , SPANDATA .GEN_AI_SYSTEM , "openai" )
173180 set_data_normalized (span , SPANDATA .GEN_AI_REQUEST_MODEL , model )
174- set_data_normalized (span , SPANDATA .GEN_AI_OPERATION_NAME , "chat" )
181+ set_data_normalized (span , SPANDATA .GEN_AI_OPERATION_NAME , operation )
175182 set_data_normalized (span , SPANDATA .AI_STREAMING , streaming )
176183
177184 # Optional attributes
@@ -194,29 +201,32 @@ def _set_request_data(span, kwargs, integration):
194201 set_data_normalized (span , SPANDATA .GEN_AI_REQUEST_TOP_P , top_p )
195202
196203
197- def _set_response_data (span , res , kwargs , integration ):
204+ def _set_response_data (span , response , kwargs , integration ):
198205 # type: (Span, Any, dict[str, Any], Integration) -> None
199- if hasattr (res , "model" ):
200- set_data_normalized (span , SPANDATA .GEN_AI_RESPONSE_MODEL , res .model )
206+ if hasattr (response , "model" ):
207+ set_data_normalized (span , SPANDATA .GEN_AI_RESPONSE_MODEL , response .model )
208+
209+ messages = kwargs .get ("messages" )
210+ if messages is None :
211+ messages = kwargs .get ("input" )
201212
202- messages = kwargs .get ("messages" , [])
213+ if isinstance (messages , str ):
214+ messages = [messages ]
203215
204- if hasattr (res , "choices" ):
216+ if hasattr (response , "choices" ):
205217 if should_send_default_pii () and integration .include_prompts :
206- response_text = [choice .message .dict () for choice in res .choices ]
218+ response_text = [choice .message .dict () for choice in response .choices ]
207219 if len (response_text ) > 0 :
208220 set_data_normalized (
209221 span ,
210222 SPANDATA .GEN_AI_RESPONSE_TEXT ,
211223 safe_serialize (response_text ),
212224 )
213225
214- _calculate_token_usage (messages , res , span , None , integration .count_tokens )
215-
216- elif hasattr (res , "_iterator" ):
226+ elif hasattr (response , "_iterator" ):
217227 data_buf : list [list [str ]] = [] # one for each choice
218228
219- old_iterator = res ._iterator
229+ old_iterator = response ._iterator
220230
221231 def new_iterator ():
222232 # type: () -> Iterator[ChatCompletionChunk]
@@ -242,7 +252,7 @@ def new_iterator():
242252 )
243253 _calculate_token_usage (
244254 messages ,
245- res ,
255+ response ,
246256 span ,
247257 all_responses ,
248258 integration .count_tokens ,
@@ -273,20 +283,19 @@ async def new_iterator_async():
273283 )
274284 _calculate_token_usage (
275285 messages ,
276- res ,
286+ response ,
277287 span ,
278288 all_responses ,
279289 integration .count_tokens ,
280290 )
281291 span .__exit__ (None , None , None )
282292
283- if str (type (res ._iterator )) == "<class 'async_generator'>" :
284- res ._iterator = new_iterator_async ()
293+ if str (type (response ._iterator )) == "<class 'async_generator'>" :
294+ response ._iterator = new_iterator_async ()
285295 else :
286- res ._iterator = new_iterator ()
296+ response ._iterator = new_iterator ()
287297
288- else :
289- set_data_normalized (span , "unknown_response" , True )
298+ _calculate_token_usage (messages , response , span , None , integration .count_tokens )
290299
291300
292301def _new_chat_completion_common (f , * args , ** kwargs ):
@@ -306,19 +315,20 @@ def _new_chat_completion_common(f, *args, **kwargs):
306315 return f (* args , ** kwargs )
307316
308317 model = kwargs .get ("model" )
318+ operation = "chat"
309319
310320 with sentry_sdk .start_span (
311321 op = consts .OP .GEN_AI_CHAT ,
312- name = f"chat { model } " ,
322+ name = f"{ operation } { model } " ,
313323 origin = OpenAIIntegration .origin ,
314324 ) as span :
315- _set_request_data (span , kwargs , integration )
325+ _set_request_data (span , kwargs , operation , integration )
316326
317- res = yield f , args , kwargs
327+ response = yield f , args , kwargs
318328
319- _set_response_data (span , res , kwargs , integration )
329+ _set_response_data (span , response , kwargs , integration )
320330
321- return res
331+ return response
322332
323333
324334def _wrap_chat_completion_create (f ):
@@ -398,54 +408,18 @@ def _new_embeddings_create_common(f, *args, **kwargs):
398408 return f (* args , ** kwargs )
399409
400410 model = kwargs .get ("model" )
411+ operation = "embeddings"
401412
402413 with sentry_sdk .start_span (
403414 op = consts .OP .GEN_AI_EMBEDDINGS ,
404- name = f"embeddings { model } " ,
415+ name = f"{ operation } { model } " ,
405416 origin = OpenAIIntegration .origin ,
406417 ) as span :
407- set_data_normalized (span , SPANDATA .GEN_AI_SYSTEM , "openai" )
408- set_data_normalized (span , SPANDATA .GEN_AI_REQUEST_MODEL , model )
409- set_data_normalized (span , SPANDATA .GEN_AI_OPERATION_NAME , "embeddings" )
410-
411- if "input" in kwargs and (
412- should_send_default_pii () and integration .include_prompts
413- ):
414- if isinstance (kwargs ["input" ], str ):
415- set_data_normalized (
416- span , SPANDATA .GEN_AI_REQUEST_MESSAGES , [kwargs ["input" ]]
417- )
418- elif (
419- isinstance (kwargs ["input" ], list )
420- and len (kwargs ["input" ]) > 0
421- and isinstance (kwargs ["input" ][0 ], str )
422- ):
423- set_data_normalized (
424- span , SPANDATA .GEN_AI_REQUEST_MESSAGES , kwargs ["input" ]
425- )
418+ _set_request_data (span , kwargs , operation , integration )
426419
427420 response = yield f , args , kwargs
428421
429- input_tokens = 0
430- total_tokens = 0
431- if hasattr (response , "usage" ):
432- if hasattr (response .usage , "prompt_tokens" ) and isinstance (
433- response .usage .prompt_tokens , int
434- ):
435- input_tokens = response .usage .prompt_tokens
436- if hasattr (response .usage , "total_tokens" ) and isinstance (
437- response .usage .total_tokens , int
438- ):
439- total_tokens = response .usage .total_tokens
440-
441- if input_tokens == 0 :
442- input_tokens = integration .count_tokens (kwargs ["input" ] or "" )
443-
444- record_token_usage (
445- span ,
446- input_tokens = input_tokens ,
447- total_tokens = total_tokens or input_tokens ,
448- )
422+ _set_response_data (span , response , kwargs , integration )
449423
450424 return response
451425
0 commit comments