1- import json
21import time
32from typing import Any , Dict , List , Optional , Sequence , Tuple , TypedDict , Union
43from uuid import UUID
@@ -170,7 +169,31 @@ def _convert_message_dict(
170169 if function_call :
171170 msg ["function_call" ] = function_call
172171 else :
173- msg ["content" ] = kwargs .get ("content" , "" )
172+ content = kwargs .get ("content" )
173+ if isinstance (content , list ):
174+ tool_calls = []
175+ content_parts = []
176+ for item in content :
177+ if item .get ("type" ) == "tool_use" :
178+ tool_calls .append (
179+ {
180+ "id" : item .get ("id" ),
181+ "type" : "function" ,
182+ "function" : {
183+ "name" : item .get ("name" ),
184+ "arguments" : item .get ("input" ),
185+ },
186+ }
187+ )
188+ elif item .get ("type" ) == "text" :
189+ content_parts .append ({"type" : "text" , "text" : item .get ("text" )})
190+
191+ if tool_calls :
192+ msg ["tool_calls" ] = tool_calls
193+ if content_parts :
194+ msg ["content" ] = content_parts # type: ignore
195+ else :
196+ msg ["content" ] = content # type: ignore
174197
175198 return msg
176199
@@ -182,6 +205,7 @@ def _convert_message(
182205 return self ._convert_message_dict (
183206 message ,
184207 )
208+
185209 function_call = message .additional_kwargs .get ("function_call" )
186210
187211 msg = GenerationMessage (
@@ -199,7 +223,32 @@ def _convert_message(
199223 if function_call :
200224 msg ["function_call" ] = function_call
201225 else :
202- msg ["content" ] = message .content # type: ignore
226+ if isinstance (message .content , list ):
227+ tool_calls = []
228+ content_parts = []
229+ for item in message .content :
230+ if isinstance (item , str ):
231+ continue
232+ if item .get ("type" ) == "tool_use" :
233+ tool_calls .append (
234+ {
235+ "id" : item .get ("id" ),
236+ "type" : "function" ,
237+ "function" : {
238+ "name" : item .get ("name" ),
239+ "arguments" : item .get ("input" ),
240+ },
241+ }
242+ )
243+ elif item .get ("type" ) == "text" :
244+ content_parts .append ({"type" : "text" , "text" : item .get ("text" )})
245+
246+ if tool_calls :
247+ msg ["tool_calls" ] = tool_calls
248+ if content_parts :
249+ msg ["content" ] = content_parts # type: ignore
250+ else :
251+ msg ["content" ] = message .content # type: ignore
203252
204253 return msg
205254
@@ -236,7 +285,12 @@ def _build_llm_settings(
236285 if "functions" in settings :
237286 tools = [{"type" : "function" , "function" : f } for f in settings ["functions" ]]
238287 if "tools" in settings :
239- tools = settings ["tools" ]
288+ tools = [
289+ {"type" : "function" , "function" : t }
290+ if t .get ("type" ) != "function"
291+ else t
292+ for t in settings ["tools" ]
293+ ]
240294 return provider , model , tools , settings
241295
242296
@@ -492,11 +546,12 @@ async def _start_trace(self, run: Run) -> None:
492546 parent_id = parent_id ,
493547 )
494548 step .start = utc_now ()
495- step .input , language = process_content (run .inputs )
496- if language is not None :
497- if step .metadata is None :
498- step .metadata = {}
499- step .metadata ["language" ] = language
549+ if step .metadata is None :
550+ step .metadata = {}
551+ if step_type != "llm" :
552+ step .input , language = process_content (run .inputs )
553+ if language is not None :
554+ step .metadata ["language" ] = language
500555
501556 step .tags = run .tags
502557 self .steps [str (run .id )] = step
@@ -560,9 +615,6 @@ async def _on_run_update(self, run: Run) -> None:
560615 break
561616
562617 current_step .language = "json"
563- current_step .output = json .dumps (
564- message_completion , indent = 4 , ensure_ascii = False
565- )
566618 else :
567619 completion_start = self .completion_generations [str (run .id )]
568620 completion = generation .get ("text" , "" )
@@ -601,13 +653,14 @@ async def _on_run_update(self, run: Run) -> None:
601653 output = outputs .get (output_keys [0 ], outputs )
602654
603655 if current_step :
604- current_step .output = (
605- output [0 ]
606- if isinstance (output , Sequence )
607- and not isinstance (output , str )
608- and len (output )
609- else output
610- )
656+ if current_step .type != "llm" :
657+ current_step .output = (
658+ output [0 ]
659+ if isinstance (output , Sequence )
660+ and not isinstance (output , str )
661+ and len (output )
662+ else output
663+ )
611664 current_step .end = utc_now ()
612665 await current_step .update ()
613666
0 commit comments