Skip to content

Commit 6b2f5de

Browse files
committed
refactor embeddings
1 parent 2f04a23 commit 6b2f5de

File tree

1 file changed

+39
-65
lines changed

1 file changed

+39
-65
lines changed

sentry_sdk/integrations/openai.py

Lines changed: 39 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,18 @@ def _calculate_token_usage(
156156
)
157157

158158

159-
def _set_request_data(span, kwargs, integration):
160-
# type: (Span, dict[str, Any], Integration) -> None
159+
def _set_request_data(span, kwargs, operation, integration):
160+
# type: (Span, dict[str, Any], str, Integration) -> None
161161
messages = kwargs.get("messages")
162+
if messages is None:
163+
messages = kwargs.get("input")
164+
165+
if isinstance(messages, str):
166+
messages = [messages]
167+
162168
if (
163169
messages is not None
170+
and len(messages) > 0
164171
and should_send_default_pii()
165172
and integration.include_prompts
166173
):
@@ -171,7 +178,7 @@ def _set_request_data(span, kwargs, integration):
171178
streaming = kwargs.get("stream")
172179
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "openai")
173180
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MODEL, model)
174-
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")
181+
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, operation)
175182
set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)
176183

177184
# Optional attributes
@@ -194,29 +201,32 @@ def _set_request_data(span, kwargs, integration):
194201
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_TOP_P, top_p)
195202

196203

197-
def _set_response_data(span, res, kwargs, integration):
204+
def _set_response_data(span, response, kwargs, integration):
198205
# type: (Span, Any, dict[str, Any], Integration) -> None
199-
if hasattr(res, "model"):
200-
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, res.model)
206+
if hasattr(response, "model"):
207+
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, response.model)
208+
209+
messages = kwargs.get("messages")
210+
if messages is None:
211+
messages = kwargs.get("input")
201212

202-
messages = kwargs.get("messages", [])
213+
if isinstance(messages, str):
214+
messages = [messages]
203215

204-
if hasattr(res, "choices"):
216+
if hasattr(response, "choices"):
205217
if should_send_default_pii() and integration.include_prompts:
206-
response_text = [choice.message.dict() for choice in res.choices]
218+
response_text = [choice.message.dict() for choice in response.choices]
207219
if len(response_text) > 0:
208220
set_data_normalized(
209221
span,
210222
SPANDATA.GEN_AI_RESPONSE_TEXT,
211223
safe_serialize(response_text),
212224
)
213225

214-
_calculate_token_usage(messages, res, span, None, integration.count_tokens)
215-
216-
elif hasattr(res, "_iterator"):
226+
elif hasattr(response, "_iterator"):
217227
data_buf: list[list[str]] = [] # one for each choice
218228

219-
old_iterator = res._iterator
229+
old_iterator = response._iterator
220230

221231
def new_iterator():
222232
# type: () -> Iterator[ChatCompletionChunk]
@@ -242,7 +252,7 @@ def new_iterator():
242252
)
243253
_calculate_token_usage(
244254
messages,
245-
res,
255+
response,
246256
span,
247257
all_responses,
248258
integration.count_tokens,
@@ -273,20 +283,19 @@ async def new_iterator_async():
273283
)
274284
_calculate_token_usage(
275285
messages,
276-
res,
286+
response,
277287
span,
278288
all_responses,
279289
integration.count_tokens,
280290
)
281291
span.__exit__(None, None, None)
282292

283-
if str(type(res._iterator)) == "<class 'async_generator'>":
284-
res._iterator = new_iterator_async()
293+
if str(type(response._iterator)) == "<class 'async_generator'>":
294+
response._iterator = new_iterator_async()
285295
else:
286-
res._iterator = new_iterator()
296+
response._iterator = new_iterator()
287297

288-
else:
289-
set_data_normalized(span, "unknown_response", True)
298+
_calculate_token_usage(messages, response, span, None, integration.count_tokens)
290299

291300

292301
def _new_chat_completion_common(f, *args, **kwargs):
@@ -306,19 +315,20 @@ def _new_chat_completion_common(f, *args, **kwargs):
306315
return f(*args, **kwargs)
307316

308317
model = kwargs.get("model")
318+
operation = "chat"
309319

310320
with sentry_sdk.start_span(
311321
op=consts.OP.GEN_AI_CHAT,
312-
name=f"chat {model}",
322+
name=f"{operation} {model}",
313323
origin=OpenAIIntegration.origin,
314324
) as span:
315-
_set_request_data(span, kwargs, integration)
325+
_set_request_data(span, kwargs, operation, integration)
316326

317-
res = yield f, args, kwargs
327+
response = yield f, args, kwargs
318328

319-
_set_response_data(span, res, kwargs, integration)
329+
_set_response_data(span, response, kwargs, integration)
320330

321-
return res
331+
return response
322332

323333

324334
def _wrap_chat_completion_create(f):
@@ -398,54 +408,18 @@ def _new_embeddings_create_common(f, *args, **kwargs):
398408
return f(*args, **kwargs)
399409

400410
model = kwargs.get("model")
411+
operation = "embeddings"
401412

402413
with sentry_sdk.start_span(
403414
op=consts.OP.GEN_AI_EMBEDDINGS,
404-
name=f"embeddings {model}",
415+
name=f"{operation} {model}",
405416
origin=OpenAIIntegration.origin,
406417
) as span:
407-
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "openai")
408-
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MODEL, model)
409-
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
410-
411-
if "input" in kwargs and (
412-
should_send_default_pii() and integration.include_prompts
413-
):
414-
if isinstance(kwargs["input"], str):
415-
set_data_normalized(
416-
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, [kwargs["input"]]
417-
)
418-
elif (
419-
isinstance(kwargs["input"], list)
420-
and len(kwargs["input"]) > 0
421-
and isinstance(kwargs["input"][0], str)
422-
):
423-
set_data_normalized(
424-
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, kwargs["input"]
425-
)
418+
_set_request_data(span, kwargs, operation, integration)
426419

427420
response = yield f, args, kwargs
428421

429-
input_tokens = 0
430-
total_tokens = 0
431-
if hasattr(response, "usage"):
432-
if hasattr(response.usage, "prompt_tokens") and isinstance(
433-
response.usage.prompt_tokens, int
434-
):
435-
input_tokens = response.usage.prompt_tokens
436-
if hasattr(response.usage, "total_tokens") and isinstance(
437-
response.usage.total_tokens, int
438-
):
439-
total_tokens = response.usage.total_tokens
440-
441-
if input_tokens == 0:
442-
input_tokens = integration.count_tokens(kwargs["input"] or "")
443-
444-
record_token_usage(
445-
span,
446-
input_tokens=input_tokens,
447-
total_tokens=total_tokens or input_tokens,
448-
)
422+
_set_response_data(span, response, kwargs, integration)
449423

450424
return response
451425

0 commit comments

Comments (0)