diff --git a/sentry_sdk/consts.py b/sentry_sdk/consts.py index 3ed8efd506..ea9878ea03 100644 --- a/sentry_sdk/consts.py +++ b/sentry_sdk/consts.py @@ -794,6 +794,7 @@ class OP: GEN_AI_CHAT = "gen_ai.chat" GEN_AI_EMBEDDINGS = "gen_ai.embeddings" GEN_AI_EXECUTE_TOOL = "gen_ai.execute_tool" + GEN_AI_GENERATE_TEXT = "gen_ai.generate_text" GEN_AI_HANDOFF = "gen_ai.handoff" GEN_AI_PIPELINE = "gen_ai.pipeline" GEN_AI_INVOKE_AGENT = "gen_ai.invoke_agent" diff --git a/sentry_sdk/integrations/huggingface_hub.py b/sentry_sdk/integrations/huggingface_hub.py index 2dfcb5925a..9e7f6ff2d5 100644 --- a/sentry_sdk/integrations/huggingface_hub.py +++ b/sentry_sdk/integrations/huggingface_hub.py @@ -1,15 +1,14 @@ from functools import wraps -from sentry_sdk import consts -from sentry_sdk.ai.monitoring import record_token_usage -from sentry_sdk.ai.utils import set_data_normalized -from sentry_sdk.consts import SPANDATA from typing import Any, Iterable, Callable import sentry_sdk -from sentry_sdk.scope import should_send_default_pii +from sentry_sdk.ai.monitoring import record_token_usage +from sentry_sdk.ai.utils import set_data_normalized +from sentry_sdk.consts import OP, SPANDATA from sentry_sdk.integrations import DidNotEnable, Integration +from sentry_sdk.scope import should_send_default_pii from sentry_sdk.utils import ( capture_internal_exceptions, event_from_exception, @@ -18,7 +17,7 @@ try: import huggingface_hub.inference._client - from huggingface_hub import ChatCompletionStreamOutput, TextGenerationOutput + from huggingface_hub import ChatCompletionOutput, TextGenerationOutput except ImportError: raise DidNotEnable("Huggingface not installed") @@ -34,9 +33,18 @@ def __init__(self, include_prompts=True): @staticmethod def setup_once(): # type: () -> None + + # Other tasks that can be called: https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks huggingface_hub.inference._client.InferenceClient.text_generation = ( - _wrap_text_generation( 
- huggingface_hub.inference._client.InferenceClient.text_generation + _wrap_huggingface_task( + huggingface_hub.inference._client.InferenceClient.text_generation, + OP.GEN_AI_GENERATE_TEXT, + ) + ) + huggingface_hub.inference._client.InferenceClient.chat_completion = ( + _wrap_huggingface_task( + huggingface_hub.inference._client.InferenceClient.chat_completion, + OP.GEN_AI_CHAT, ) ) @@ -51,10 +59,10 @@ def _capture_exception(exc): sentry_sdk.capture_event(event, hint=hint) -def _wrap_text_generation(f): - # type: (Callable[..., Any]) -> Callable[..., Any] +def _wrap_huggingface_task(f, op): + # type: (Callable[..., Any], str) -> Callable[..., Any] @wraps(f) - def new_text_generation(*args, **kwargs): + def new_huggingface_task(*args, **kwargs): # type: (*Any, **Any) -> Any integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration) if integration is None: @@ -62,54 +70,110 @@ def new_text_generation(*args, **kwargs): if "prompt" in kwargs: prompt = kwargs["prompt"] + elif "messages" in kwargs: + prompt = kwargs["messages"] elif len(args) >= 2: kwargs["prompt"] = args[1] prompt = kwargs["prompt"] args = (args[0],) + args[2:] else: - # invalid call, let it return error + # invalid call, don't instrument, let it return error return f(*args, **kwargs) - model = kwargs.get("model") - streaming = kwargs.get("stream") + client = args[0] + model = client.model or kwargs.get("model") or "" + operation_name = op.split(".")[-1] span = sentry_sdk.start_span( - op=consts.OP.HUGGINGFACE_HUB_CHAT_COMPLETIONS_CREATE, - name="Text Generation", + op=op, + name=f"{operation_name} {model}", origin=HuggingfaceHubIntegration.origin, ) span.__enter__() + + span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, operation_name) + + if model: + span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model) + + # Input attributes + attribute_mapping = { + "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, + "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY, + "max_tokens": 
SPANDATA.GEN_AI_REQUEST_MAX_TOKENS, + "presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY, + "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE, + "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P, + "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K, + "stream": SPANDATA.GEN_AI_RESPONSE_STREAMING, + } + for attribute, span_attribute in attribute_mapping.items(): + value = kwargs.get(attribute, None) + if value is not None: + if isinstance(value, (int, float, bool, str)): + span.set_data(span_attribute, value) + else: + set_data_normalized(span, span_attribute, value, unpack=False) + try: res = f(*args, **kwargs) except Exception as e: + span.set_status("error") _capture_exception(e) span.__exit__(None, None, None) raise e from None with capture_internal_exceptions(): - if should_send_default_pii() and integration.include_prompts: - set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompt) + # Output attributes + if hasattr(res, "model"): + model = res.model + if model: + span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, model) + + if hasattr(res, "details") and res.details is not None: + finish_reason = getattr(res.details, "finish_reason", None) + if finish_reason: + span.set_data( + SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS, finish_reason + ) + + try: + tool_calls = res.choices[0].message.tool_calls + except Exception: + tool_calls = [] - set_data_normalized(span, SPANDATA.AI_MODEL_ID, model) - set_data_normalized(span, SPANDATA.AI_STREAMING, streaming) + if len(tool_calls) > 0: + set_data_normalized( + span, + SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS, + tool_calls, + unpack=False, + ) + + if should_send_default_pii() and integration.include_prompts: + set_data_normalized( + span, SPANDATA.GEN_AI_REQUEST_MESSAGES, prompt, unpack=False + ) if isinstance(res, str): if should_send_default_pii() and integration.include_prompts: - set_data_normalized( - span, - SPANDATA.AI_RESPONSES, - [res], - ) + if res: + set_data_normalized( + span, + SPANDATA.GEN_AI_RESPONSE_TEXT, + res, + ) 
span.__exit__(None, None, None) return res if isinstance(res, TextGenerationOutput): if should_send_default_pii() and integration.include_prompts: - set_data_normalized( - span, - SPANDATA.AI_RESPONSES, - [res.generated_text], - ) + if res.generated_text: + set_data_normalized( + span, + SPANDATA.GEN_AI_RESPONSE_TEXT, + res.generated_text, + ) if res.details is not None and res.details.generated_tokens > 0: record_token_usage( span, @@ -118,16 +182,38 @@ def new_text_generation(*args, **kwargs): span.__exit__(None, None, None) return res + if isinstance(res, ChatCompletionOutput): + if should_send_default_pii() and integration.include_prompts: + text_response = "".join( + [ + x.get("message", {}).get("content", None) or "" + for x in res.choices + ] + ) + if text_response: + set_data_normalized( + span, + SPANDATA.GEN_AI_RESPONSE_TEXT, + text_response, + ) + if hasattr(res, "usage") and res.usage is not None: + record_token_usage( + span, + input_tokens=res.usage.prompt_tokens, + output_tokens=res.usage.completion_tokens, + total_tokens=res.usage.total_tokens, + ) + span.__exit__(None, None, None) + return res + if not isinstance(res, Iterable): - # we only know how to deal with strings and iterables, ignore - set_data_normalized(span, "unknown_response", True) span.__exit__(None, None, None) return res if kwargs.get("details", False): - # res is Iterable[TextGenerationStreamOutput] + def new_details_iterator(): - # type: () -> Iterable[ChatCompletionStreamOutput] + # type: () -> Iterable[Any] with capture_internal_exceptions(): tokens_used = 0 data_buf: list[str] = [] @@ -144,9 +230,13 @@ def new_details_iterator(): and should_send_default_pii() and integration.include_prompts ): - set_data_normalized( - span, SPANDATA.AI_RESPONSES, "".join(data_buf) - ) + text_response = "".join(data_buf) + if text_response: + set_data_normalized( + span, + SPANDATA.GEN_AI_RESPONSE_TEXT, + text_response, + ) if tokens_used > 0: record_token_usage( span, @@ -171,11 +261,15 @@ def 
new_iterator(): and should_send_default_pii() and integration.include_prompts ): - set_data_normalized( - span, SPANDATA.AI_RESPONSES, "".join(data_buf) - ) + text_response = "".join(data_buf) + if text_response: + set_data_normalized( + span, + SPANDATA.GEN_AI_RESPONSE_TEXT, + text_response, + ) span.__exit__(None, None, None) return new_iterator() - return new_text_generation + return new_huggingface_task diff --git a/tests/integrations/huggingface_hub/test_huggingface_hub.py b/tests/integrations/huggingface_hub/test_huggingface_hub.py index df0c6c6d76..a9dc450168 100644 --- a/tests/integrations/huggingface_hub/test_huggingface_hub.py +++ b/tests/integrations/huggingface_hub/test_huggingface_hub.py @@ -8,7 +8,6 @@ from huggingface_hub.errors import OverloadedError from sentry_sdk import start_transaction -from sentry_sdk.consts import SPANDATA from sentry_sdk.integrations.huggingface_hub import HuggingfaceHubIntegration @@ -68,14 +67,14 @@ def test_nonstreaming_chat_completion( tx = events[0] assert tx["type"] == "transaction" span = tx["spans"][0] - assert span["op"] == "ai.chat_completions.create.huggingface_hub" + assert span["op"] == "gen_ai.generate_text" if send_default_pii and include_prompts: - assert "hello" in span["data"][SPANDATA.AI_INPUT_MESSAGES] - assert "the model response" in span["data"][SPANDATA.AI_RESPONSES] + assert "hello" in span["data"]["gen_ai.request.messages"] + assert "the model response" in span["data"]["gen_ai.response.text"] else: - assert SPANDATA.AI_INPUT_MESSAGES not in span["data"] - assert SPANDATA.AI_RESPONSES not in span["data"] + assert "gen_ai.request.messages" not in span["data"] + assert "gen_ai.response.text" not in span["data"] if details_arg: assert span["data"]["gen_ai.usage.total_tokens"] == 10 @@ -127,14 +126,14 @@ def test_streaming_chat_completion( tx = events[0] assert tx["type"] == "transaction" span = tx["spans"][0] - assert span["op"] == "ai.chat_completions.create.huggingface_hub" + assert span["op"] == 
"gen_ai.generate_text" if send_default_pii and include_prompts: - assert "hello" in span["data"][SPANDATA.AI_INPUT_MESSAGES] - assert "the model response" in span["data"][SPANDATA.AI_RESPONSES] + assert "hello" in span["data"]["gen_ai.request.messages"] + assert "the model response" in span["data"]["gen_ai.response.text"] else: - assert SPANDATA.AI_INPUT_MESSAGES not in span["data"] - assert SPANDATA.AI_RESPONSES not in span["data"] + assert "gen_ai.request.messages" not in span["data"] + assert "gen_ai.response.text" not in span["data"] if details_arg: assert span["data"]["gen_ai.usage.total_tokens"] == 10