@@ -156,35 +156,19 @@ def _calculate_token_usage(
156156 )
157157
158158
159- def _new_chat_completion_common (f , * args , ** kwargs ):
160- # type: (Any, Any, Any) -> Any
161- integration = sentry_sdk .get_client ().get_integration (OpenAIIntegration )
162- if integration is None :
163- return f (* args , ** kwargs )
159+ def _set_request_data (span , kwargs , integration ):
160+ # type: (Span, dict[str, Any], Integration) -> None
161+ messages = kwargs .get ("messages" )
162+ if (
163+ messages is not None
164+ and should_send_default_pii ()
165+ and integration .include_prompts
166+ ):
167+ set_data_normalized (span , SPANDATA .GEN_AI_REQUEST_MESSAGES , messages )
164168
165- if "messages" not in kwargs :
166- # invalid call (in all versions of openai), let it return error
167- return f (* args , ** kwargs )
168-
169- try :
170- iter (kwargs ["messages" ])
171- except TypeError :
172- # invalid call (in all versions), messages must be iterable
173- return f (* args , ** kwargs )
174-
175- kwargs ["messages" ] = list (kwargs ["messages" ])
176- messages = kwargs ["messages" ]
169+ # Common attributes
177170 model = kwargs .get ("model" )
178171 streaming = kwargs .get ("stream" )
179-
180- span = sentry_sdk .start_span (
181- op = consts .OP .GEN_AI_CHAT ,
182- name = f"chat { model } " ,
183- origin = OpenAIIntegration .origin ,
184- )
185- span .__enter__ ()
186-
187- # Common attributes
188172 set_data_normalized (span , SPANDATA .GEN_AI_SYSTEM , "openai" )
189173 set_data_normalized (span , SPANDATA .GEN_AI_REQUEST_MODEL , model )
190174 set_data_normalized (span , SPANDATA .GEN_AI_OPERATION_NAME , "chat" )
@@ -209,106 +193,131 @@ def _new_chat_completion_common(f, *args, **kwargs):
209193 if top_p is not None :
210194 set_data_normalized (span , SPANDATA .GEN_AI_REQUEST_TOP_P , top_p )
211195
212- res = yield f , args , kwargs
213196
214- with capture_internal_exceptions ():
215- if should_send_default_pii () and integration .include_prompts :
216- set_data_normalized (span , SPANDATA .GEN_AI_REQUEST_MESSAGES , messages )
197+ def _set_response_data (span , res , kwargs , integration ):
198+ # type: (Span, Any, dict[str, Any], Integration) -> None
199+ if hasattr (res , "model" ):
200+ set_data_normalized (span , SPANDATA .GEN_AI_RESPONSE_MODEL , res .model )
201+
202+ messages = kwargs .get ("messages" , [])
217203
218- if hasattr (res , "model" ):
219- set_data_normalized (span , SPANDATA .GEN_AI_RESPONSE_MODEL , res .model )
204+ if hasattr (res , "choices" ):
205+ if should_send_default_pii () and integration .include_prompts :
206+ response_text = [choice .message .dict () for choice in res .choices ]
207+ if len (response_text ) > 0 :
208+ set_data_normalized (
209+ span ,
210+ SPANDATA .GEN_AI_RESPONSE_TEXT ,
211+ safe_serialize (response_text ),
212+ )
220213
221- if hasattr (res , "choices" ):
222- if should_send_default_pii () and integration .include_prompts :
223- response_text = [choice .message .dict () for choice in res .choices ]
224- if len (response_text ) > 0 :
225- set_data_normalized (
214+ _calculate_token_usage (messages , res , span , None , integration .count_tokens )
215+
216+ elif hasattr (res , "_iterator" ):
217+ data_buf : list [list [str ]] = [] # one for each choice
218+
219+ old_iterator = res ._iterator
220+
221+ def new_iterator ():
222+ # type: () -> Iterator[ChatCompletionChunk]
223+ with capture_internal_exceptions ():
224+ for x in old_iterator :
225+ if hasattr (x , "choices" ):
226+ choice_index = 0
227+ for choice in x .choices :
228+ if hasattr (choice , "delta" ) and hasattr (
229+ choice .delta , "content"
230+ ):
231+ content = choice .delta .content
232+ if len (data_buf ) <= choice_index :
233+ data_buf .append ([])
234+ data_buf [choice_index ].append (content or "" )
235+ choice_index += 1
236+ yield x
237+ if len (data_buf ) > 0 :
238+ all_responses = list (map (lambda chunk : "" .join (chunk ), data_buf ))
239+ if should_send_default_pii () and integration .include_prompts :
240+ set_data_normalized (
241+ span , SPANDATA .GEN_AI_RESPONSE_TEXT , all_responses
242+ )
243+ _calculate_token_usage (
244+ messages ,
245+ res ,
226246 span ,
227- SPANDATA . GEN_AI_RESPONSE_TEXT ,
228- safe_serialize ( response_text ) ,
247+ all_responses ,
248+ integration . count_tokens ,
229249 )
230-
231- _calculate_token_usage (messages , res , span , None , integration .count_tokens )
232250 span .__exit__ (None , None , None )
233- elif hasattr (res , "_iterator" ):
234- data_buf : list [list [str ]] = [] # one for each choice
235-
236- old_iterator = res ._iterator
237-
238- def new_iterator ():
239- # type: () -> Iterator[ChatCompletionChunk]
240- with capture_internal_exceptions ():
241- for x in old_iterator :
242- if hasattr (x , "choices" ):
243- choice_index = 0
244- for choice in x .choices :
245- if hasattr (choice , "delta" ) and hasattr (
246- choice .delta , "content"
247- ):
248- content = choice .delta .content
249- if len (data_buf ) <= choice_index :
250- data_buf .append ([])
251- data_buf [choice_index ].append (content or "" )
252- choice_index += 1
253- yield x
254- if len (data_buf ) > 0 :
255- all_responses = list (
256- map (lambda chunk : "" .join (chunk ), data_buf )
257- )
258- if should_send_default_pii () and integration .include_prompts :
259- set_data_normalized (
260- span , SPANDATA .GEN_AI_RESPONSE_TEXT , all_responses
261- )
262- _calculate_token_usage (
263- messages ,
264- res ,
265- span ,
266- all_responses ,
267- integration .count_tokens ,
268- )
269- span .__exit__ (None , None , None )
270-
271- async def new_iterator_async ():
272- # type: () -> AsyncIterator[ChatCompletionChunk]
273- with capture_internal_exceptions ():
274- async for x in old_iterator :
275- if hasattr (x , "choices" ):
276- choice_index = 0
277- for choice in x .choices :
278- if hasattr (choice , "delta" ) and hasattr (
279- choice .delta , "content"
280- ):
281- content = choice .delta .content
282- if len (data_buf ) <= choice_index :
283- data_buf .append ([])
284- data_buf [choice_index ].append (content or "" )
285- choice_index += 1
286- yield x
287- if len (data_buf ) > 0 :
288- all_responses = list (
289- map (lambda chunk : "" .join (chunk ), data_buf )
290- )
291- if should_send_default_pii () and integration .include_prompts :
292- set_data_normalized (
293- span , SPANDATA .GEN_AI_RESPONSE_TEXT , all_responses
294- )
295- _calculate_token_usage (
296- messages ,
297- res ,
298- span ,
299- all_responses ,
300- integration .count_tokens ,
301- )
302- span .__exit__ (None , None , None )
303251
304- if str (type (res ._iterator )) == "<class 'async_generator'>" :
305- res ._iterator = new_iterator_async ()
306- else :
307- res ._iterator = new_iterator ()
252+ async def new_iterator_async ():
253+ # type: () -> AsyncIterator[ChatCompletionChunk]
254+ with capture_internal_exceptions ():
255+ async for x in old_iterator :
256+ if hasattr (x , "choices" ):
257+ choice_index = 0
258+ for choice in x .choices :
259+ if hasattr (choice , "delta" ) and hasattr (
260+ choice .delta , "content"
261+ ):
262+ content = choice .delta .content
263+ if len (data_buf ) <= choice_index :
264+ data_buf .append ([])
265+ data_buf [choice_index ].append (content or "" )
266+ choice_index += 1
267+ yield x
268+ if len (data_buf ) > 0 :
269+ all_responses = list (map (lambda chunk : "" .join (chunk ), data_buf ))
270+ if should_send_default_pii () and integration .include_prompts :
271+ set_data_normalized (
272+ span , SPANDATA .GEN_AI_RESPONSE_TEXT , all_responses
273+ )
274+ _calculate_token_usage (
275+ messages ,
276+ res ,
277+ span ,
278+ all_responses ,
279+ integration .count_tokens ,
280+ )
281+ span .__exit__ (None , None , None )
308282
283+ if str (type (res ._iterator )) == "<class 'async_generator'>" :
284+ res ._iterator = new_iterator_async ()
309285 else :
310- set_data_normalized (span , "unknown_response" , True )
311- span .__exit__ (None , None , None )
286+ res ._iterator = new_iterator ()
287+
288+ else :
289+ set_data_normalized (span , "unknown_response" , True )
290+
291+
292+ def _new_chat_completion_common (f , * args , ** kwargs ):
293+ # type: (Any, Any, Any) -> Any
294+ integration = sentry_sdk .get_client ().get_integration (OpenAIIntegration )
295+ if integration is None :
296+ return f (* args , ** kwargs )
297+
298+ if "messages" not in kwargs :
299+ # invalid call (in all versions of openai), let it return error
300+ return f (* args , ** kwargs )
301+
302+ try :
303+ iter (kwargs ["messages" ])
304+ except TypeError :
305+ # invalid call (in all versions), messages must be iterable
306+ return f (* args , ** kwargs )
307+
308+ model = kwargs .get ("model" )
309+
310+ with sentry_sdk .start_span (
311+ op = consts .OP .GEN_AI_CHAT ,
312+ name = f"chat { model } " ,
313+ origin = OpenAIIntegration .origin ,
314+ ) as span :
315+ _set_request_data (span , kwargs , integration )
316+
317+ res = yield f , args , kwargs
318+
319+ _set_response_data (span , res , kwargs , integration )
320+
312321 return res
313322
314323
0 commit comments