Skip to content
1 change: 1 addition & 0 deletions sentry_sdk/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ class OP:
GEN_AI_CHAT = "gen_ai.chat"
GEN_AI_EMBEDDINGS = "gen_ai.embeddings"
GEN_AI_EXECUTE_TOOL = "gen_ai.execute_tool"
GEN_AI_GENERATE_TEXT = "gen_ai.generate_text"
GEN_AI_HANDOFF = "gen_ai.handoff"
GEN_AI_PIPELINE = "gen_ai.pipeline"
GEN_AI_INVOKE_AGENT = "gen_ai.invoke_agent"
Expand Down
176 changes: 135 additions & 41 deletions sentry_sdk/integrations/huggingface_hub.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from functools import wraps

from sentry_sdk import consts
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.ai.utils import set_data_normalized
from sentry_sdk.consts import SPANDATA

from typing import Any, Iterable, Callable

import sentry_sdk
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.ai.utils import set_data_normalized
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.utils import (
capture_internal_exceptions,
event_from_exception,
Expand All @@ -18,7 +17,7 @@
try:
import huggingface_hub.inference._client

from huggingface_hub import ChatCompletionStreamOutput, TextGenerationOutput
from huggingface_hub import ChatCompletionOutput, TextGenerationOutput
except ImportError:
raise DidNotEnable("Huggingface not installed")

Expand All @@ -34,9 +33,18 @@ def __init__(self, include_prompts=True):
@staticmethod
def setup_once():
# type: () -> None

# Other tasks that can be called: https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks
huggingface_hub.inference._client.InferenceClient.text_generation = (
_wrap_text_generation(
huggingface_hub.inference._client.InferenceClient.text_generation
_wrap_huggingface_task(
huggingface_hub.inference._client.InferenceClient.text_generation,
OP.GEN_AI_GENERATE_TEXT,
)
)
huggingface_hub.inference._client.InferenceClient.chat_completion = (
_wrap_huggingface_task(
huggingface_hub.inference._client.InferenceClient.chat_completion,
OP.GEN_AI_CHAT,
)
)

Expand All @@ -51,65 +59,121 @@ def _capture_exception(exc):
sentry_sdk.capture_event(event, hint=hint)


def _wrap_text_generation(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
def _wrap_huggingface_task(f, op):
# type: (Callable[..., Any], str) -> Callable[..., Any]
@wraps(f)
def new_text_generation(*args, **kwargs):
def new_huggingface_task(*args, **kwargs):
# type: (*Any, **Any) -> Any
integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration)
if integration is None:
return f(*args, **kwargs)

if "prompt" in kwargs:
prompt = kwargs["prompt"]
elif "messages" in kwargs:
prompt = kwargs["messages"]
elif len(args) >= 2:
kwargs["prompt"] = args[1]
prompt = kwargs["prompt"]
args = (args[0],) + args[2:]
else:
# invalid call, let it return error
            # invalid call, don't instrument, let it return error
return f(*args, **kwargs)

model = kwargs.get("model")
streaming = kwargs.get("stream")
client = args[0]
model = client.model or kwargs.get("model") or ""
operation_name = op.split(".")[-1]

span = sentry_sdk.start_span(
op=consts.OP.HUGGINGFACE_HUB_CHAT_COMPLETIONS_CREATE,
name="Text Generation",
op=op,
name=f"{operation_name} {model}",
origin=HuggingfaceHubIntegration.origin,
)
span.__enter__()

span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, operation_name)

if model:
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)

# Input attributes
attribute_mapping = {
"tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
"frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
"max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
"presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
"temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
"top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
"top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
"stream": SPANDATA.GEN_AI_RESPONSE_STREAMING,
}
for attribute, span_attribute in attribute_mapping.items():
value = kwargs.get(attribute, None)
if value is not None:
if isinstance(value, (int, float, bool, str)):
span.set_data(span_attribute, value)
else:
set_data_normalized(span, span_attribute, value, unpack=False)

try:
res = f(*args, **kwargs)
except Exception as e:
span.set_status("error")
_capture_exception(e)
span.__exit__(None, None, None)
raise e from None

with capture_internal_exceptions():
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompt)
# Output attributes
if hasattr(res, "model"):
model = res.model
if model:
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, model)

if hasattr(res, "details") and res.details is not None:
finish_reason = getattr(res.details, "finish_reason", None)
if finish_reason:
span.set_data(
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS, finish_reason
)

try:
tool_calls = res.choices[0].message.tool_calls
except Exception:
tool_calls = []

set_data_normalized(span, SPANDATA.AI_MODEL_ID, model)
set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)
if len(tool_calls) > 0:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
tool_calls,
unpack=False,
)

if should_send_default_pii() and integration.include_prompts:
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, prompt, unpack=False
)

if isinstance(res, str):
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(
span,
SPANDATA.AI_RESPONSES,
[res],
)
if res:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_TEXT,
res,
)
span.__exit__(None, None, None)
return res

if isinstance(res, TextGenerationOutput):
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(
span,
SPANDATA.AI_RESPONSES,
[res.generated_text],
)
if res.generated_text:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_TEXT,
res.generated_text,
)
if res.details is not None and res.details.generated_tokens > 0:
record_token_usage(
span,
Expand All @@ -118,16 +182,38 @@ def new_text_generation(*args, **kwargs):
span.__exit__(None, None, None)
return res

if isinstance(res, ChatCompletionOutput):
if should_send_default_pii() and integration.include_prompts:
text_response = "".join(
[
x.get("message", {}).get("content", None) or ""
for x in res.choices
]
)
if text_response:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_TEXT,
text_response,
)
if hasattr(res, "usage") and res.usage is not None:
record_token_usage(
span,
input_tokens=res.usage.prompt_tokens,
output_tokens=res.usage.completion_tokens,
total_tokens=res.usage.total_tokens,
)
span.__exit__(None, None, None)
return res

if not isinstance(res, Iterable):
# we only know how to deal with strings and iterables, ignore
set_data_normalized(span, "unknown_response", True)
span.__exit__(None, None, None)
return res

if kwargs.get("details", False):
# res is Iterable[TextGenerationStreamOutput]

def new_details_iterator():
# type: () -> Iterable[ChatCompletionStreamOutput]
# type: () -> Iterable[Any]
with capture_internal_exceptions():
tokens_used = 0
data_buf: list[str] = []
Expand All @@ -144,9 +230,13 @@ def new_details_iterator():
and should_send_default_pii()
and integration.include_prompts
):
set_data_normalized(
span, SPANDATA.AI_RESPONSES, "".join(data_buf)
)
text_response = "".join(data_buf)
if text_response:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_TEXT,
text_response,
)
if tokens_used > 0:
record_token_usage(
span,
Expand All @@ -171,11 +261,15 @@ def new_iterator():
and should_send_default_pii()
and integration.include_prompts
):
set_data_normalized(
span, SPANDATA.AI_RESPONSES, "".join(data_buf)
)
text_response = "".join(data_buf)
if text_response:
set_data_normalized(
span,
SPANDATA.GEN_AI_RESPONSE_TEXT,
text_response,
)
span.__exit__(None, None, None)

return new_iterator()

return new_text_generation
return new_huggingface_task
21 changes: 10 additions & 11 deletions tests/integrations/huggingface_hub/test_huggingface_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from huggingface_hub.errors import OverloadedError

from sentry_sdk import start_transaction
from sentry_sdk.consts import SPANDATA
from sentry_sdk.integrations.huggingface_hub import HuggingfaceHubIntegration


Expand Down Expand Up @@ -68,14 +67,14 @@ def test_nonstreaming_chat_completion(
tx = events[0]
assert tx["type"] == "transaction"
span = tx["spans"][0]
assert span["op"] == "ai.chat_completions.create.huggingface_hub"
assert span["op"] == "gen_ai.generate_text"

if send_default_pii and include_prompts:
assert "hello" in span["data"][SPANDATA.AI_INPUT_MESSAGES]
assert "the model response" in span["data"][SPANDATA.AI_RESPONSES]
assert "hello" in span["data"]["gen_ai.request.messages"]
assert "the model response" in span["data"]["gen_ai.response.text"]
else:
assert SPANDATA.AI_INPUT_MESSAGES not in span["data"]
assert SPANDATA.AI_RESPONSES not in span["data"]
assert "gen_ai.request.messages" not in span["data"]
assert "gen_ai.response.text" not in span["data"]

if details_arg:
assert span["data"]["gen_ai.usage.total_tokens"] == 10
Expand Down Expand Up @@ -127,14 +126,14 @@ def test_streaming_chat_completion(
tx = events[0]
assert tx["type"] == "transaction"
span = tx["spans"][0]
assert span["op"] == "ai.chat_completions.create.huggingface_hub"
assert span["op"] == "gen_ai.generate_text"

if send_default_pii and include_prompts:
assert "hello" in span["data"][SPANDATA.AI_INPUT_MESSAGES]
assert "the model response" in span["data"][SPANDATA.AI_RESPONSES]
assert "hello" in span["data"]["gen_ai.request.messages"]
assert "the model response" in span["data"]["gen_ai.response.text"]
else:
assert SPANDATA.AI_INPUT_MESSAGES not in span["data"]
assert SPANDATA.AI_RESPONSES not in span["data"]
assert "gen_ai.request.messages" not in span["data"]
assert "gen_ai.response.text" not in span["data"]

if details_arg:
assert span["data"]["gen_ai.usage.total_tokens"] == 10
Expand Down