Skip to content
155 changes: 114 additions & 41 deletions sentry_sdk/integrations/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,48 +219,54 @@ def _calculate_token_usage(
)


def _commmon_set_input_data(
def _set_responses_api_input_data(
span: "Span",
kwargs: "dict[str, Any]",
integration: "OpenAIIntegration",
) -> None:
# Input attributes: Common
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "openai")

# Input attributes: Optional
kwargs_keys_to_attributes = {
"model": SPANDATA.GEN_AI_REQUEST_MODEL,
"stream": SPANDATA.GEN_AI_RESPONSE_STREAMING,
"max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
"presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
"frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
"temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
"top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
}
for key, attribute in kwargs_keys_to_attributes.items():
value = kwargs.get(key)

if value is not None and _is_given(value):
set_data_normalized(span, attribute, value)

# Input attributes: Tools
explicit_instructions: "Union[Optional[str], Omit]" = kwargs.get("instructions")
messages: "Optional[Union[str, ResponseInputParam]]" = kwargs.get("input")

tools = kwargs.get("tools")
if tools is not None and _is_given(tools) and len(tools) > 0:
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, safe_serialize(tools)
)

model = kwargs.get("model")
if model is not None and _is_given(model):
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MODEL, model)

def _set_responses_api_input_data(
span: "Span",
kwargs: "dict[str, Any]",
integration: "OpenAIIntegration",
) -> None:
explicit_instructions: "Union[Optional[str], Omit]" = kwargs.get("instructions")
messages: "Optional[Union[str, ResponseInputParam]]" = kwargs.get("input")
stream = kwargs.get("stream")
if stream is not None and _is_given(stream):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_STREAMING, stream)

max_tokens = kwargs.get("max_output_tokens")
if max_tokens is not None and _is_given(max_tokens):
span.set_data(SPANDATA.GEN_AI_REQUEST_MAX_TOKENS, max_tokens)

presence_penalty = kwargs.get("presence_penalty")
if presence_penalty is not None and _is_given(presence_penalty):
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY, presence_penalty
)

frequency_penalty = kwargs.get("frequency_penalty")
if frequency_penalty is not None and _is_given(frequency_penalty):
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY, frequency_penalty
)

temperature = kwargs.get("temperature")
if temperature is not None and _is_given(temperature):
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_TEMPERATURE, temperature)

top_p = kwargs.get("top_p")
if top_p is not None and _is_given(top_p):
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_TOP_P, top_p)

if not should_send_default_pii() or not integration.include_prompts:
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "responses")
_commmon_set_input_data(span, kwargs)
return

if (
Expand All @@ -281,12 +287,10 @@ def _set_responses_api_input_data(
)

set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "responses")
_commmon_set_input_data(span, kwargs)
return

if messages is None:
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "responses")
_commmon_set_input_data(span, kwargs)
return

instructions_text_parts: "list[TextPart]" = []
Expand Down Expand Up @@ -319,7 +323,6 @@ def _set_responses_api_input_data(
)

set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "responses")
_commmon_set_input_data(span, kwargs)
return

non_system_messages = [
Expand All @@ -335,7 +338,6 @@ def _set_responses_api_input_data(
)

set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "responses")
_commmon_set_input_data(span, kwargs)


def _set_completions_api_input_data(
Expand All @@ -347,13 +349,50 @@ def _set_completions_api_input_data(
"messages"
)

tools = kwargs.get("tools")
if tools is not None and _is_given(tools) and len(tools) > 0:
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, safe_serialize(tools)
)

model = kwargs.get("model")
if model is not None and _is_given(model):
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MODEL, model)

stream = kwargs.get("stream")
if stream is not None and _is_given(stream):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_STREAMING, stream)

max_tokens = kwargs.get("max_tokens")
if max_tokens is not None and _is_given(max_tokens):
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MAX_TOKENS, max_tokens)

presence_penalty = kwargs.get("presence_penalty")
if presence_penalty is not None and _is_given(presence_penalty):
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY, presence_penalty
)

frequency_penalty = kwargs.get("frequency_penalty")
if frequency_penalty is not None and _is_given(frequency_penalty):
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY, frequency_penalty
)

temperature = kwargs.get("temperature")
if temperature is not None and _is_given(temperature):
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_TEMPERATURE, temperature)

top_p = kwargs.get("top_p")
if top_p is not None and _is_given(top_p):
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_TOP_P, top_p)

if (
not should_send_default_pii()
or not integration.include_prompts
or messages is None
):
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")
_commmon_set_input_data(span, kwargs)
return

if isinstance(messages, str):
Expand All @@ -365,13 +404,11 @@ def _set_completions_api_input_data(
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages_data, unpack=False
)
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")
_commmon_set_input_data(span, kwargs)
return

# dict special case following https://github.com/openai/openai-python/blob/3e0c05b84a2056870abf3bd6a5e7849020209cc3/src/openai/_utils/_transform.py#L194-L197
if not isinstance(messages, Iterable) or isinstance(messages, dict):
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")
_commmon_set_input_data(span, kwargs)
return

messages = list(messages)
Expand Down Expand Up @@ -399,7 +436,6 @@ def _set_completions_api_input_data(
)

set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")
_commmon_set_input_data(span, kwargs)


def _set_embeddings_input_data(
Expand All @@ -411,19 +447,55 @@ def _set_embeddings_input_data(
"input"
)

tools = kwargs.get("tools")
if tools is not None and _is_given(tools) and len(tools) > 0:
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, safe_serialize(tools)
)

model = kwargs.get("model")
if model is not None and _is_given(model):
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MODEL, model)

stream = kwargs.get("stream")
if stream is not None and _is_given(stream):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_STREAMING, stream)

max_tokens = kwargs.get("max_tokens")
if max_tokens is not None and _is_given(max_tokens):
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MAX_TOKENS, max_tokens)

presence_penalty = kwargs.get("presence_penalty")
if presence_penalty is not None and _is_given(presence_penalty):
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY, presence_penalty
)

frequency_penalty = kwargs.get("frequency_penalty")
if frequency_penalty is not None and _is_given(frequency_penalty):
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY, frequency_penalty
)

temperature = kwargs.get("temperature")
if temperature is not None and _is_given(temperature):
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_TEMPERATURE, temperature)

top_p = kwargs.get("top_p")
if top_p is not None and _is_given(top_p):
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_TOP_P, top_p)

if (
not should_send_default_pii()
or not integration.include_prompts
or messages is None
):
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
_commmon_set_input_data(span, kwargs)

return

if isinstance(messages, str):
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
_commmon_set_input_data(span, kwargs)

normalized_messages = normalize_message_roles([messages]) # type: ignore
scope = sentry_sdk.get_current_scope()
Expand All @@ -440,7 +512,6 @@ def _set_embeddings_input_data(
# dict special case following https://github.com/openai/openai-python/blob/3e0c05b84a2056870abf3bd6a5e7849020209cc3/src/openai/_utils/_transform.py#L194-L197
if not isinstance(messages, Iterable) or isinstance(messages, dict):
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
_commmon_set_input_data(span, kwargs)
return

messages = list(messages)
Expand All @@ -458,7 +529,6 @@ def _set_embeddings_input_data(
)

set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
_commmon_set_input_data(span, kwargs)


def _set_common_output_data(
Expand Down Expand Up @@ -551,6 +621,7 @@ def _new_chat_completion_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any
)
span.__enter__()

span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
_set_completions_api_input_data(span, kwargs, integration)

start_time = time.perf_counter()
Expand Down Expand Up @@ -945,6 +1016,7 @@ def _new_embeddings_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
name=f"embeddings {model}",
origin=OpenAIIntegration.origin,
) as span:
span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
_set_embeddings_input_data(span, kwargs, integration)

response = yield f, args, kwargs
Expand Down Expand Up @@ -1036,6 +1108,7 @@ def _new_responses_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "An
)
span.__enter__()

span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
_set_responses_api_input_data(span, kwargs, integration)

start_time = time.perf_counter()
Expand Down
Loading
Loading