Skip to content

Commit 2f04a23

Browse files
committed
organize code
1 parent 5b11b7f commit 2f04a23

File tree

1 file changed

+126
-117
lines changed

sentry_sdk/integrations/openai.py

Lines changed: 126 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -156,35 +156,19 @@ def _calculate_token_usage(
156156
)
157157

158158

159-
def _new_chat_completion_common(f, *args, **kwargs):
160-
# type: (Any, Any, Any) -> Any
161-
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
162-
if integration is None:
163-
return f(*args, **kwargs)
159+
def _set_request_data(span, kwargs, integration):
160+
# type: (Span, dict[str, Any], Integration) -> None
161+
messages = kwargs.get("messages")
162+
if (
163+
messages is not None
164+
and should_send_default_pii()
165+
and integration.include_prompts
166+
):
167+
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages)
164168

165-
if "messages" not in kwargs:
166-
# invalid call (in all versions of openai), let it return error
167-
return f(*args, **kwargs)
168-
169-
try:
170-
iter(kwargs["messages"])
171-
except TypeError:
172-
# invalid call (in all versions), messages must be iterable
173-
return f(*args, **kwargs)
174-
175-
kwargs["messages"] = list(kwargs["messages"])
176-
messages = kwargs["messages"]
169+
# Common attributes
177170
model = kwargs.get("model")
178171
streaming = kwargs.get("stream")
179-
180-
span = sentry_sdk.start_span(
181-
op=consts.OP.GEN_AI_CHAT,
182-
name=f"chat {model}",
183-
origin=OpenAIIntegration.origin,
184-
)
185-
span.__enter__()
186-
187-
# Common attributes
188172
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "openai")
189173
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MODEL, model)
190174
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")
@@ -209,106 +193,131 @@ def _new_chat_completion_common(f, *args, **kwargs):
209193
if top_p is not None:
210194
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_TOP_P, top_p)
211195

212-
res = yield f, args, kwargs
213196

214-
with capture_internal_exceptions():
215-
if should_send_default_pii() and integration.include_prompts:
216-
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages)
197+
def _set_response_data(span, res, kwargs, integration):
    # type: (Span, Any, dict[str, Any], Integration) -> None
    """Record response attributes (model, text, token usage) on *span*.

    Handles three response shapes:
    - a non-streaming completion exposing ``choices`` — response text and
      token usage are recorded immediately;
    - a streaming response exposing ``_iterator`` — the iterator is wrapped
      so response text and token usage are captured lazily while the caller
      consumes the stream;
    - anything else, which is tagged as an unknown response.

    Response text is only attached when PII sending is enabled AND the
    integration opted into ``include_prompts``.
    """
    if hasattr(res, "model"):
        set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, res.model)

    # The request messages are needed again here for token counting.
    messages = kwargs.get("messages", [])

    if hasattr(res, "choices"):
        # Non-streaming completion: the full response is available now.
        if should_send_default_pii() and integration.include_prompts:
            response_text = [choice.message.dict() for choice in res.choices]
            if len(response_text) > 0:
                set_data_normalized(
                    span,
                    SPANDATA.GEN_AI_RESPONSE_TEXT,
                    safe_serialize(response_text),
                )

        _calculate_token_usage(messages, res, span, None, integration.count_tokens)

    elif hasattr(res, "_iterator"):
        # Streaming completion: wrap the underlying iterator so each chunk's
        # delta content is accumulated per choice while being passed through
        # to the caller unchanged.
        data_buf: list[list[str]] = []  # one buffer for each choice

        old_iterator = res._iterator

        def new_iterator():
            # type: () -> Iterator[ChatCompletionChunk]
            with capture_internal_exceptions():
                for x in old_iterator:
                    if hasattr(x, "choices"):
                        choice_index = 0
                        for choice in x.choices:
                            if hasattr(choice, "delta") and hasattr(
                                choice.delta, "content"
                            ):
                                content = choice.delta.content
                                if len(data_buf) <= choice_index:
                                    data_buf.append([])
                                data_buf[choice_index].append(content or "")
                                choice_index += 1
                    yield x
                if len(data_buf) > 0:
                    all_responses = list(map(lambda chunk: "".join(chunk), data_buf))
                    if should_send_default_pii() and integration.include_prompts:
                        set_data_normalized(
                            span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses
                        )
                    _calculate_token_usage(
                        messages,
                        res,
                        span,
                        all_responses,
                        integration.count_tokens,
                    )
            # NOTE(review): the span is finished here once the stream is
            # exhausted, but the caller appears to manage the span's lifetime
            # as well — confirm the span is not finished twice.
            span.__exit__(None, None, None)

        async def new_iterator_async():
            # type: () -> AsyncIterator[ChatCompletionChunk]
            with capture_internal_exceptions():
                async for x in old_iterator:
                    if hasattr(x, "choices"):
                        choice_index = 0
                        for choice in x.choices:
                            if hasattr(choice, "delta") and hasattr(
                                choice.delta, "content"
                            ):
                                content = choice.delta.content
                                if len(data_buf) <= choice_index:
                                    data_buf.append([])
                                data_buf[choice_index].append(content or "")
                                choice_index += 1
                    yield x
                if len(data_buf) > 0:
                    all_responses = list(map(lambda chunk: "".join(chunk), data_buf))
                    if should_send_default_pii() and integration.include_prompts:
                        set_data_normalized(
                            span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses
                        )
                    _calculate_token_usage(
                        messages,
                        res,
                        span,
                        all_responses,
                        integration.count_tokens,
                    )
            # NOTE(review): same double-finish concern as new_iterator above.
            span.__exit__(None, None, None)

        # Pick the sync or async wrapper to match the underlying iterator.
        if str(type(res._iterator)) == "<class 'async_generator'>":
            res._iterator = new_iterator_async()
        else:
            res._iterator = new_iterator()

    else:
        set_data_normalized(span, "unknown_response", True)
290+
291+
292+
def _new_chat_completion_common(f, *args, **kwargs):
    # type: (Any, Any, Any) -> Any
    """Shared instrumentation generator for chat-completion calls.

    Validates the call, opens a ``gen_ai.chat`` span, yields control back to
    the sync/async wrapper (which performs the actual API call and sends the
    result in via the generator protocol), then records request and response
    data on the span.

    Invalid calls (no integration, missing or non-iterable ``messages``) are
    passed straight through to ``f`` so the OpenAI client raises its own
    error.
    """
    integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
    if integration is None:
        return f(*args, **kwargs)

    if "messages" not in kwargs:
        # invalid call (in all versions of openai), let it return error
        return f(*args, **kwargs)

    try:
        iter(kwargs["messages"])
    except TypeError:
        # invalid call (in all versions), messages must be iterable
        return f(*args, **kwargs)

    # NOTE(review): messages is validated as iterable but no longer
    # materialized with list(); a one-shot iterator would be consumed by the
    # first read — confirm callers always pass a sequence.

    model = kwargs.get("model")

    with sentry_sdk.start_span(
        op=consts.OP.GEN_AI_CHAT,
        name=f"chat {model}",
        origin=OpenAIIntegration.origin,
    ) as span:
        _set_request_data(span, kwargs, integration)

        # The wrapper executes f(*args, **kwargs) and sends the result back.
        res = yield f, args, kwargs

        _set_response_data(span, res, kwargs, integration)

    return res
313322

314323

0 commit comments

Comments (0)