Commit 3ad5e50

cleanup
1 parent: c80a413

sentry_sdk/integrations/openai.py

Lines changed: 35 additions & 48 deletions
@@ -156,8 +156,9 @@ def _calculate_token_usage(
 
 
 # TODO: rename to _set_input_data and _set_output_data
-def _set_request_data(span, kwargs, operation, integration):
+def _set_input_data(span, kwargs, operation, integration):
     # type: (Span, dict[str, Any], str, Integration) -> None
+    # Input messages (the prompt or data sent to the model)
     messages = kwargs.get("messages")
     if messages is None:
         messages = kwargs.get("input")
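For reference, the renamed helper starts from the lookup shown above: OpenAI's Chat Completions API passes the prompt as "messages", while the Responses API passes it as "input", so the code falls back from one to the other. A minimal standalone sketch of that fallback (the helper name here is hypothetical, not part of the SDK):

def _get_input_messages(kwargs):
    # Prefer Chat Completions style "messages"; fall back to Responses-style "input".
    messages = kwargs.get("messages")
    if messages is None:
        messages = kwargs.get("input")
    return messages

assert _get_input_messages({"messages": [{"role": "user", "content": "hi"}]}) == [
    {"role": "user", "content": "hi"}
]
assert _get_input_messages({"input": "hi"}) == "hi"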
@@ -173,51 +174,34 @@ def _set_request_data(span, kwargs, operation, integration):
     ):
         set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages)
 
-    # TODO: make mapping and loop over kwargs to set attributes
-    # Common attributes
-    model = kwargs.get("model")
+    # Input attributes: Common
     set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "openai")
-    set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MODEL, model)
     set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, operation)
 
-    # Optional attributes
-    streaming = kwargs.get("stream")
-    if streaming is not None:
-        set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)
-
-    max_tokens = kwargs.get("max_tokens")
-    if max_tokens is not None:
-        set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MAX_TOKENS, max_tokens)
-
-    presence_penalty = kwargs.get("presence_penalty")
-    if presence_penalty is not None:
-        set_data_normalized(
-            span, SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY, presence_penalty
-        )
-
-    frequency_penalty = kwargs.get("frequency_penalty")
-    if frequency_penalty is not None:
-        set_data_normalized(
-            span, SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY, frequency_penalty
-        )
-
-    temperature = kwargs.get("temperature")
-    if temperature is not None:
-        set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_TEMPERATURE, temperature)
-
-    top_p = kwargs.get("top_p")
-    if top_p is not None:
-        set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_TOP_P, top_p)
-
-    # Tools
+    # Input attributes: Optional
+    kwargs_keys_to_attributes = {
+        "model": SPANDATA.GEN_AI_REQUEST_MODEL,
+        "stream": SPANDATA.AI_STREAMING,
+        "max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
+        "presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
+        "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
+        "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
+        "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
+    }
+    for key, attribute in kwargs_keys_to_attributes.items():
+        value = kwargs.get(key)
+        if value is not None:
+            set_data_normalized(span, attribute, value)
+
+    # Input attributes: Tools
     tools = kwargs.get("tools", [])
     if tools is not None and len(tools) > 0:
         set_data_normalized(
             span, SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, safe_serialize(tools)
         )
 
 
-def _set_response_data(span, response, kwargs, integration, finish_span=True):
+def _set_output_data(span, response, kwargs, integration, finish_span=True):
     # type: (Span, Any, dict[str, Any], Integration, bool) -> None
     if hasattr(response, "model"):
         set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, response.model)
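The new kwargs_keys_to_attributes mapping above replaces the per-kwarg None checks with a single loop. A self-contained sketch of the same pattern, using hypothetical attribute-name strings and a plain dict in place of set_data_normalized and the SPANDATA constants:

# Hypothetical attribute names; the real ones are SPANDATA constants in the SDK.
KWARGS_TO_ATTRIBUTES = {
    "model": "gen_ai.request.model",
    "stream": "ai.streaming",
    "max_tokens": "gen_ai.request.max_tokens",
    "temperature": "gen_ai.request.temperature",
    "top_p": "gen_ai.request.top_p",
}

def collect_request_attributes(kwargs):
    # Only kwargs that were actually provided (not None) become attributes.
    attributes = {}
    for key, attribute in KWARGS_TO_ATTRIBUTES.items():
        value = kwargs.get(key)
        if value is not None:
            attributes[attribute] = value
    return attributes

print(collect_request_attributes({"model": "gpt-4o-mini", "temperature": 0.2}))
# {'gen_ai.request.model': 'gpt-4o-mini', 'gen_ai.request.temperature': 0.2}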
@@ -238,7 +222,8 @@ def _set_response_data(span, response, kwargs, integration, finish_span=True):
                     SPANDATA.GEN_AI_RESPONSE_TEXT,
                     safe_serialize(response_text),
                 )
-        span.__exit__(None, None, None)
+        if finish_span:
+            span.__exit__(None, None, None)
 
     elif hasattr(response, "output"):
         if should_send_default_pii() and integration.include_prompts:
@@ -249,6 +234,8 @@ def _set_response_data(span, response, kwargs, integration, finish_span=True):
                     SPANDATA.GEN_AI_RESPONSE_TEXT,
                     safe_serialize(response_text),
                 )
+        if finish_span:
+            span.__exit__(None, None, None)
 
     elif hasattr(response, "_iterator"):
         data_buf: list[list[str]] = []  # one for each choice
@@ -284,7 +271,8 @@ def new_iterator():
                     all_responses,
                     integration.count_tokens,
                 )
-            span.__exit__(None, None, None)
+            if finish_span:
+                span.__exit__(None, None, None)
 
         async def new_iterator_async():
             # type: () -> AsyncIterator[ChatCompletionChunk]
@@ -315,14 +303,14 @@ async def new_iterator_async():
                     all_responses,
                     integration.count_tokens,
                 )
-            span.__exit__(None, None, None)
+            if finish_span:
+                span.__exit__(None, None, None)
 
         if str(type(response._iterator)) == "<class 'async_generator'>":
             response._iterator = new_iterator_async()
         else:
             response._iterator = new_iterator()
     else:
-        set_data_normalized(span, "unknown_response", True)
         if finish_span:
             span.__exit__(None, None, None)
 
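The streaming branches above buffer the chunks and only record the assembled output and close the span once the wrapped iterator has been fully consumed. A minimal sketch of that wrapping idea, with hypothetical record/close_span callbacks standing in for set_data_normalized and span.__exit__:

def wrap_stream(chunks, record, close_span, finish_span=True):
    # Replace the response iterator with one that tees chunks into a buffer,
    # then reports the assembled text and (optionally) closes the span at the end.
    def new_iterator():
        buffered = []
        for chunk in chunks:
            buffered.append(chunk)
            yield chunk
        record("".join(buffered))
        if finish_span:
            close_span()
    return new_iterator()

captured = {}
stream = wrap_stream(
    iter(["Hel", "lo"]),
    record=lambda text: captured.update(text=text),
    close_span=lambda: captured.update(finished=True),
)
assert list(stream) == ["Hel", "lo"]  # consuming the stream drains the wrapper
assert captured == {"text": "Hello", "finished": True}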
@@ -355,11 +343,11 @@ def _new_chat_completion_common(f, *args, **kwargs):
     )
     span.__enter__()
 
-    _set_request_data(span, kwargs, operation, integration)
+    _set_input_data(span, kwargs, operation, integration)
 
     response = yield f, args, kwargs
 
-    _set_response_data(span, response, kwargs, integration)
+    _set_output_data(span, response, kwargs, integration, finish_span=True)
 
     return response
 
@@ -448,11 +436,11 @@ def _new_embeddings_create_common(f, *args, **kwargs):
         name=f"{operation} {model}",
         origin=OpenAIIntegration.origin,
     ) as span:
-        _set_request_data(span, kwargs, operation, integration)
+        _set_input_data(span, kwargs, operation, integration)
 
         response = yield f, args, kwargs
 
-        _set_response_data(span, response, kwargs, integration, finish_span=False)
+        _set_output_data(span, response, kwargs, integration, finish_span=False)
 
         return response
 
@@ -540,13 +528,12 @@ def _new_responses_create_common(f, *args, **kwargs):
         origin=OpenAIIntegration.origin,
     )
     span.__enter__()
-    _set_request_data(span, kwargs, operation, integration)
 
-    response = yield f, args, kwargs
+    _set_input_data(span, kwargs, operation, integration)
 
-    _set_response_data(span, response, kwargs, integration)
+    response = yield f, args, kwargs
 
-    span.__exit__(None, None, None)
+    _set_output_data(span, response, kwargs, integration, finish_span=True)
 
     return response
 
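Taken together, the call sites above show the contract of _set_output_data: chat completions and responses enter their span manually and pass finish_span=True so the helper closes it, while embeddings run inside a with-block and pass finish_span=False so the context manager closes it. A self-contained sketch of that contract; DemoSpan and set_output are hypothetical stand-ins, not the Sentry SDK Span class:

class DemoSpan:
    # A toy span with the same enter/exit surface used in the diff above.
    def __init__(self, name):
        self.name = name
        self.data = {}
        self.finished = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.finished = True

def set_output(span, response, finish_span=True):
    # Record the response, and only close the span when the caller is not
    # already managing its lifetime with a with-block.
    span.data["response"] = response
    if finish_span:
        span.__exit__(None, None, None)

# Manually entered span (chat completions / responses pattern).
span = DemoSpan("chat")
span.__enter__()
set_output(span, "hello", finish_span=True)
assert span.finished

# Context-managed span (embeddings pattern): the with-block closes it.
with DemoSpan("embeddings") as span:
    set_output(span, [0.1, 0.2], finish_span=False)
assert span.finished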