17
17
try :
18
18
import huggingface_hub .inference ._client
19
19
20
- from huggingface_hub import ChatCompletionStreamOutput , TextGenerationOutput
20
+ from huggingface_hub import ChatCompletionOutput , TextGenerationOutput
21
21
except ImportError :
22
22
raise DidNotEnable ("Huggingface not installed" )
23
23
@@ -40,6 +40,11 @@ def setup_once():
40
40
huggingface_hub .inference ._client .InferenceClient .text_generation
41
41
)
42
42
)
43
+ huggingface_hub .inference ._client .InferenceClient .chat_completion = (
44
+ _wrap_text_generation (
45
+ huggingface_hub .inference ._client .InferenceClient .chat_completion
46
+ )
47
+ )
43
48
44
49
45
50
def _capture_exception (exc ):
@@ -63,12 +68,14 @@ def new_text_generation(*args, **kwargs):
63
68
64
69
if "prompt" in kwargs :
65
70
prompt = kwargs ["prompt" ]
71
+ elif "messages" in kwargs :
72
+ prompt = kwargs ["messages" ]
66
73
elif len (args ) >= 2 :
67
74
kwargs ["prompt" ] = args [1 ]
68
75
prompt = kwargs ["prompt" ]
69
76
args = (args [0 ],) + args [2 :]
70
77
else :
71
- # invalid call, let it return error
78
+ # invalid call, dont instrument, let it return error
72
79
return f (* args , ** kwargs )
73
80
74
81
client = args [0 ]
@@ -95,7 +102,9 @@ def new_text_generation(*args, **kwargs):
95
102
96
103
with capture_internal_exceptions ():
97
104
if should_send_default_pii () and integration .include_prompts :
98
- set_data_normalized (span , SPANDATA .GEN_AI_REQUEST_MESSAGES , prompt )
105
+ set_data_normalized (
106
+ span , SPANDATA .GEN_AI_REQUEST_MESSAGES , prompt , unpack = False
107
+ )
99
108
100
109
span .set_data (SPANDATA .GEN_AI_RESPONSE_STREAMING , streaming )
101
110
@@ -104,17 +113,20 @@ def new_text_generation(*args, **kwargs):
104
113
set_data_normalized (
105
114
span ,
106
115
SPANDATA .GEN_AI_RESPONSE_TEXT ,
107
- [ res ] ,
116
+ res ,
108
117
)
109
118
span .__exit__ (None , None , None )
110
119
return res
111
120
112
121
if isinstance (res , TextGenerationOutput ):
113
122
if should_send_default_pii () and integration .include_prompts :
123
+ import ipdb
124
+
125
+ ipdb .set_trace ()
114
126
set_data_normalized (
115
127
span ,
116
128
SPANDATA .GEN_AI_RESPONSE_TEXT ,
117
- [ res .generated_text ] ,
129
+ res .generated_text ,
118
130
)
119
131
if res .details is not None and res .details .generated_tokens > 0 :
120
132
record_token_usage (
@@ -124,15 +136,35 @@ def new_text_generation(*args, **kwargs):
124
136
span .__exit__ (None , None , None )
125
137
return res
126
138
139
+ if isinstance (res , ChatCompletionOutput ):
140
+ if should_send_default_pii () and integration .include_prompts :
141
+ text_response = "" .join (
142
+ [x .get ("message" , {}).get ("content" ) for x in res .choices ]
143
+ )
144
+ set_data_normalized (
145
+ span ,
146
+ SPANDATA .GEN_AI_RESPONSE_TEXT ,
147
+ text_response ,
148
+ )
149
+ if hasattr (res , "usage" ) and res .usage is not None :
150
+ record_token_usage (
151
+ span ,
152
+ input_tokens = res .usage .prompt_tokens ,
153
+ output_tokens = res .usage .completion_tokens ,
154
+ total_tokens = res .usage .total_tokens ,
155
+ )
156
+ span .__exit__ (None , None , None )
157
+ return res
158
+
127
159
if not isinstance (res , Iterable ):
128
160
# we only know how to deal with strings and iterables, ignore
129
161
span .__exit__ (None , None , None )
130
162
return res
131
163
132
164
if kwargs .get ("details" , False ):
133
- # res is Iterable[TextGenerationStreamOutput]
165
+
134
166
def new_details_iterator ():
135
- # type: () -> Iterable[ChatCompletionStreamOutput ]
167
+ # type: () -> Iterable[Any ]
136
168
with capture_internal_exceptions ():
137
169
tokens_used = 0
138
170
data_buf : list [str ] = []
@@ -150,7 +182,9 @@ def new_details_iterator():
150
182
and integration .include_prompts
151
183
):
152
184
set_data_normalized (
153
- span , SPANDATA .GEN_AI_RESPONSE_TEXT , "" .join (data_buf )
185
+ span ,
186
+ SPANDATA .GEN_AI_RESPONSE_TEXT ,
187
+ "" .join (data_buf ),
154
188
)
155
189
if tokens_used > 0 :
156
190
record_token_usage (
@@ -177,7 +211,9 @@ def new_iterator():
177
211
and integration .include_prompts
178
212
):
179
213
set_data_normalized (
180
- span , SPANDATA .GEN_AI_RESPONSE_TEXT , "" .join (data_buf )
214
+ span ,
215
+ SPANDATA .GEN_AI_RESPONSE_TEXT ,
216
+ "" .join (data_buf ),
181
217
)
182
218
span .__exit__ (None , None , None )
183
219
0 commit comments