 try:
     import huggingface_hub.inference._client
 
-    from huggingface_hub import ChatCompletionStreamOutput, TextGenerationOutput
+    from huggingface_hub import ChatCompletionOutput, TextGenerationOutput
 except ImportError:
     raise DidNotEnable("Huggingface not installed")
 
@@ -40,6 +40,11 @@ def setup_once():
                 huggingface_hub.inference._client.InferenceClient.text_generation
             )
         )
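+        # Patch chat_completion with the same wrapper used for text_generation;
+        # the wrapper handles both the "prompt" and the "messages" call styles.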
+        huggingface_hub.inference._client.InferenceClient.chat_completion = (
+            _wrap_text_generation(
+                huggingface_hub.inference._client.InferenceClient.chat_completion
+            )
+        )
 
 
 def _capture_exception(exc):
@@ -63,12 +68,14 @@ def new_text_generation(*args, **kwargs):
 
         if "prompt" in kwargs:
             prompt = kwargs["prompt"]
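+        # chat_completion passes the conversation as "messages" instead of "prompt"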
+        elif "messages" in kwargs:
+            prompt = kwargs["messages"]
         elif len(args) >= 2:
             kwargs["prompt"] = args[1]
             prompt = kwargs["prompt"]
             args = (args[0],) + args[2:]
         else:
-            # invalid call, let it return error
+            # invalid call, don't instrument, let it return error
             return f(*args, **kwargs)
 
         client = args[0]
@@ -95,7 +102,9 @@ def new_text_generation(*args, **kwargs):
 
         with capture_internal_exceptions():
             if should_send_default_pii() and integration.include_prompts:
-                set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MESSAGES, prompt)
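+                # unpack=False keeps the prompt / message list as a single value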
+                set_data_normalized(
+                    span, SPANDATA.GEN_AI_REQUEST_MESSAGES, prompt, unpack=False
+                )
 
             span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, streaming)
 
@@ -104,17 +113,20 @@ def new_text_generation(*args, **kwargs):
                     set_data_normalized(
                         span,
                         SPANDATA.GEN_AI_RESPONSE_TEXT,
-                        [res],
+                        res,
                     )
                 span.__exit__(None, None, None)
                 return res
 
             if isinstance(res, TextGenerationOutput):
                 if should_send_default_pii() and integration.include_prompts:
                     set_data_normalized(
                         span,
                         SPANDATA.GEN_AI_RESPONSE_TEXT,
-                        [res.generated_text],
+                        res.generated_text,
                     )
                 if res.details is not None and res.details.generated_tokens > 0:
                     record_token_usage(
@@ -124,15 +136,35 @@ def new_text_generation(*args, **kwargs):
                 span.__exit__(None, None, None)
                 return res
 
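+            # Non-streaming chat_completion responses: record the concatenated
+            # choice contents and, when available, the reported token usage.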
+            if isinstance(res, ChatCompletionOutput):
+                if should_send_default_pii() and integration.include_prompts:
+                    text_response = "".join(
+                        [x.get("message", {}).get("content") for x in res.choices]
+                    )
+                    set_data_normalized(
+                        span,
+                        SPANDATA.GEN_AI_RESPONSE_TEXT,
+                        text_response,
+                    )
+                if hasattr(res, "usage") and res.usage is not None:
+                    record_token_usage(
+                        span,
+                        input_tokens=res.usage.prompt_tokens,
+                        output_tokens=res.usage.completion_tokens,
+                        total_tokens=res.usage.total_tokens,
+                    )
+                span.__exit__(None, None, None)
+                return res
+
             if not isinstance(res, Iterable):
                 # we only know how to deal with strings and iterables, ignore
                 span.__exit__(None, None, None)
                 return res
 
             if kwargs.get("details", False):
-                # res is Iterable[TextGenerationStreamOutput]
+                # res is an iterable of stream output chunks
                 def new_details_iterator():
-                    # type: () -> Iterable[ChatCompletionStreamOutput]
+                    # type: () -> Iterable[Any]
                     with capture_internal_exceptions():
                         tokens_used = 0
                         data_buf: list[str] = []
@@ -150,7 +182,9 @@ def new_details_iterator():
                             and integration.include_prompts
                         ):
                             set_data_normalized(
-                                span, SPANDATA.GEN_AI_RESPONSE_TEXT, "".join(data_buf)
+                                span,
+                                SPANDATA.GEN_AI_RESPONSE_TEXT,
+                                "".join(data_buf),
                             )
                         if tokens_used > 0:
                             record_token_usage(
@@ -177,7 +211,9 @@ def new_iterator():
                             and integration.include_prompts
                         ):
                             set_data_normalized(
-                                span, SPANDATA.GEN_AI_RESPONSE_TEXT, "".join(data_buf)
+                                span,
+                                SPANDATA.GEN_AI_RESPONSE_TEXT,
+                                "".join(data_buf),
                             )
                         span.__exit__(None, None, None)
 