
Commit 05ef7e3

First version of supporting chat-completion
1 parent 54164cd commit 05ef7e3

File tree

1 file changed: +45 -9 lines changed


sentry_sdk/integrations/huggingface_hub.py

Lines changed: 45 additions & 9 deletions
@@ -17,7 +17,7 @@
 try:
     import huggingface_hub.inference._client

-    from huggingface_hub import ChatCompletionStreamOutput, TextGenerationOutput
+    from huggingface_hub import ChatCompletionOutput, TextGenerationOutput
 except ImportError:
     raise DidNotEnable("Huggingface not installed")

@@ -40,6 +40,11 @@ def setup_once():
                 huggingface_hub.inference._client.InferenceClient.text_generation
             )
         )
+        huggingface_hub.inference._client.InferenceClient.chat_completion = (
+            _wrap_text_generation(
+                huggingface_hub.inference._client.InferenceClient.chat_completion
+            )
+        )


 def _capture_exception(exc):
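For orientation, a generic sketch (not the SDK's code; the Client class is made up) of the patch-and-wrap pattern that setup_once applies above to both text_generation and chat_completion: the method is passed through a wrapper and assigned back onto the class.

from functools import wraps


class Client:
    """Stand-in for huggingface_hub's InferenceClient."""

    def chat_completion(self, messages):
        return {"echo": messages}


def _wrap(f):
    @wraps(f)
    def wrapper(*args, **kwargs):
        # instrumentation (spans, data capture) would go here
        print("instrumented call")
        return f(*args, **kwargs)

    return wrapper


# Same move as setup_once: replace the class attribute with the wrapped method.
Client.chat_completion = _wrap(Client.chat_completion)
print(Client().chat_completion([{"role": "user", "content": "hi"}]))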
@@ -63,12 +68,14 @@ def new_text_generation(*args, **kwargs):

         if "prompt" in kwargs:
             prompt = kwargs["prompt"]
+        elif "messages" in kwargs:
+            prompt = kwargs["messages"]
         elif len(args) >= 2:
             kwargs["prompt"] = args[1]
             prompt = kwargs["prompt"]
             args = (args[0],) + args[2:]
         else:
-            # invalid call, let it return error
+            # invalid call, dont instrument, let it return error
             return f(*args, **kwargs)

         client = args[0]
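The new elif branch means the same wrapper now understands both call shapes. A standalone sketch (not the SDK's code) of the argument handling above, with made-up example calls:

def extract_prompt(*args, **kwargs):
    # Mirrors the branch order above: explicit prompt, then chat messages,
    # then a positional prompt; anything else is treated as an invalid call.
    if "prompt" in kwargs:
        return kwargs["prompt"]
    if "messages" in kwargs:
        return kwargs["messages"]
    if len(args) >= 2:
        return args[1]  # args[0] is the client instance
    return None


print(extract_prompt(prompt="Tell me a joke"))  # text_generation style
print(extract_prompt(messages=[{"role": "user", "content": "Tell me a joke"}]))  # chat_completion style
print(extract_prompt("client-instance", "Tell me a joke"))  # positional prompt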
@@ -95,7 +102,9 @@ def new_text_generation(*args, **kwargs):

         with capture_internal_exceptions():
             if should_send_default_pii() and integration.include_prompts:
-                set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MESSAGES, prompt)
+                set_data_normalized(
+                    span, SPANDATA.GEN_AI_REQUEST_MESSAGES, prompt, unpack=False
+                )

             span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, streaming)

@@ -104,17 +113,20 @@ def new_text_generation(*args, **kwargs):
                     set_data_normalized(
                         span,
                         SPANDATA.GEN_AI_RESPONSE_TEXT,
-                        [res],
+                        res,
                     )
                 span.__exit__(None, None, None)
                 return res

             if isinstance(res, TextGenerationOutput):
                 if should_send_default_pii() and integration.include_prompts:
+                    import ipdb
+
+                    ipdb.set_trace()
                     set_data_normalized(
                         span,
                         SPANDATA.GEN_AI_RESPONSE_TEXT,
-                        [res.generated_text],
+                        res.generated_text,
                     )
                 if res.details is not None and res.details.generated_tokens > 0:
                     record_token_usage(
@@ -124,15 +136,35 @@ def new_text_generation(*args, **kwargs):
                 span.__exit__(None, None, None)
                 return res

+            if isinstance(res, ChatCompletionOutput):
+                if should_send_default_pii() and integration.include_prompts:
+                    text_response = "".join(
+                        [x.get("message", {}).get("content") for x in res.choices]
+                    )
+                    set_data_normalized(
+                        span,
+                        SPANDATA.GEN_AI_RESPONSE_TEXT,
+                        text_response,
+                    )
+                if hasattr(res, "usage") and res.usage is not None:
+                    record_token_usage(
+                        span,
+                        input_tokens=res.usage.prompt_tokens,
+                        output_tokens=res.usage.completion_tokens,
+                        total_tokens=res.usage.total_tokens,
+                    )
+                span.__exit__(None, None, None)
+                return res
+
             if not isinstance(res, Iterable):
                 # we only know how to deal with strings and iterables, ignore
                 span.__exit__(None, None, None)
                 return res

             if kwargs.get("details", False):
-                # res is Iterable[TextGenerationStreamOutput]
+
                 def new_details_iterator():
-                    # type: () -> Iterable[ChatCompletionStreamOutput]
+                    # type: () -> Iterable[Any]
                     with capture_internal_exceptions():
                         tokens_used = 0
                         data_buf: list[str] = []
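For reference, a minimal sketch (not part of the commit) of how the new ChatCompletionOutput branch assembles the captured response text; the dict-style lookups mirror the x.get("message", {}) access above, and the choices/message/content and usage field names are the assumed chat-completion shape:

# Hypothetical response data shaped like a chat-completion result.
choices = [
    {"index": 0, "message": {"role": "assistant", "content": "Hello"}},
    {"index": 1, "message": {"role": "assistant", "content": " world"}},
]
usage = {"prompt_tokens": 12, "completion_tokens": 2, "total_tokens": 14}

# Same concatenation the instrumented branch performs before storing
# GEN_AI_RESPONSE_TEXT on the span.
text_response = "".join([x.get("message", {}).get("content") for x in choices])
print(text_response)          # "Hello world"
print(usage["total_tokens"])  # token counts feed record_token_usage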
@@ -150,7 +182,9 @@ def new_details_iterator():
                             and integration.include_prompts
                         ):
                             set_data_normalized(
-                                span, SPANDATA.GEN_AI_RESPONSE_TEXT, "".join(data_buf)
+                                span,
+                                SPANDATA.GEN_AI_RESPONSE_TEXT,
+                                "".join(data_buf),
                             )
                         if tokens_used > 0:
                             record_token_usage(
@@ -177,7 +211,9 @@ def new_iterator():
                             and integration.include_prompts
                         ):
                             set_data_normalized(
-                                span, SPANDATA.GEN_AI_RESPONSE_TEXT, "".join(data_buf)
+                                span,
+                                SPANDATA.GEN_AI_RESPONSE_TEXT,
+                                "".join(data_buf),
                             )
                     span.__exit__(None, None, None)
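To exercise the new code path end to end, a usage sketch assuming the integration is registered the usual way; the DSN and model id are placeholders, and send_default_pii together with include_prompts is what gates the prompt/response capture in the branches above:

import sentry_sdk
from sentry_sdk.integrations.huggingface_hub import HuggingfaceHubIntegration
from huggingface_hub import InferenceClient

sentry_sdk.init(
    dsn="https://examplePublicKey@o0.ingest.sentry.io/0",  # placeholder DSN
    traces_sample_rate=1.0,
    send_default_pii=True,  # required for prompt/response capture
    integrations=[HuggingfaceHubIntegration(include_prompts=True)],
)

client = InferenceClient(model="some-chat-model")  # placeholder model id

with sentry_sdk.start_transaction(name="chat-completion-demo"):
    # Goes through the patched chat_completion and hits the new
    # "messages" branch of the wrapper.
    res = client.chat_completion(
        messages=[{"role": "user", "content": "Say hello"}],
        max_tokens=16,
    )
    print(res.choices[0].message.content)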
