1414from attr import dataclass
1515from langchain .schema import BaseMessage , HumanMessage
1616from openai import OpenAI
17+ from openai .types .responses import ResponseFunctionToolCall
1718from PIL import Image
1819
19- from agentlab .agents import agent_utils
2020from agentlab .analyze import inspect_results
21+ from agentlab .analyze .overlay_utils import annotate_action
2122from agentlab .experiments .exp_utils import RESULTS_DIR
2223from agentlab .experiments .loop import ExpResult , StepInfo
2324from agentlab .experiments .study import get_most_recent_study
@@ -351,7 +352,7 @@ def run_gradio(results_dir: Path):
351352 pruned_html_code = gr .Code (language = "html" , ** code_args )
352353
353354 with gr .Tab ("AXTree" ) as tab_axtree :
354- axtree_code = gr .Code ( language = None , ** code_args )
355+ axtree_code = gr .Markdown ( )
355356
356357 with gr .Tab ("Chat Messages" ) as tab_chat :
357358 chat_messages = gr .Markdown ()
@@ -536,38 +537,45 @@ def wrapper(*args, **kwargs):
536537
def update_screenshot(som_or_not: str):
    """Return the screenshot of the currently selected step, annotated with its action.

    Args:
        som_or_not: Screenshot flavour selector, forwarded unchanged to
            ``get_screenshot``.

    Returns:
        The image returned by ``get_screenshot`` (``None`` when that function
        reports no screenshot).
    """
    global info
    # Only the image is needed here; the formatted action string is discarded.
    annotated_img, _ = get_screenshot(info, som_or_not=som_or_not, annotate=True)
    return annotated_img
543542
544543
def get_screenshot(
    info: Info, step: int = None, som_or_not: str = "Raw Screenshots", annotate: bool = False
):
    """Load the screenshot of a step and, optionally, annotate it with the step's action.

    Args:
        info: Viewer state holding the experiment result and the selected step.
        step: Index of the step to load. Defaults to ``info.step`` when ``None``.
        som_or_not: ``"SOM Screenshots"`` selects the SOM variant; any other
            value selects the raw screenshot.
        annotate: When True, pass the image and the step's action to
            ``annotate_action``.

    Returns:
        A ``(img, action_colored)`` tuple. ``action_colored`` is ``None`` when
        ``annotate`` is False or when no step info exists for ``step``. Both
        elements are ``None`` when the screenshot file is missing.
    """
    if step is None:
        step = info.step

    try:
        img = info.exp_result.get_screenshot(step, som=som_or_not == "SOM Screenshots")
        if not annotate:
            # Step info is not needed for a plain screenshot, so callers asking
            # for ``info.step + 1`` never index past the end of the trajectory.
            return img, None

        try:
            step_info = info.exp_result.steps_info[step]
        except IndexError:
            # No recorded step to take an action from: return the bare image.
            return img, None

        # ``obs`` may be unset on a step; treat that as "no extra properties".
        obs = step_info.obs
        properties = obs.get("extra_element_properties", None) if obs else None
        # NOTE(review): annotate_action presumably draws on ``img`` in place and
        # returns a display string for the action -- confirm in overlay_utils.
        action_colored = annotate_action(
            img, action_string=step_info.action, properties=properties
        )
        return img, action_colored
    except FileNotFoundError:
        return None, None
553562
554563
def update_screenshot_pair(som_or_not: str):
    """Return the screenshots of the current step and of the step that follows it.

    Only the current step's screenshot is requested with action annotation.

    Args:
        som_or_not: Screenshot flavour selector, forwarded unchanged to
            ``get_screenshot``.

    Returns:
        A ``(current, following)`` tuple of images; either element is whatever
        ``get_screenshot`` returned for that step.
    """
    global info
    current_step = info.step
    current, _ = get_screenshot(info, current_step, som_or_not, annotate=True)
    following, _ = get_screenshot(info, current_step + 1, som_or_not)
    return current, following
565569
566570
567571def update_screenshot_gallery (som_or_not : str ):
568572 global info
569- screenshots = info .exp_result .get_screenshots (som = som_or_not == "SOM Screenshots" )
573+ max_steps = len (info .exp_result .steps_info )
574+
575+ screenshots = [get_screenshot (info , step = i , som_or_not = som_or_not )[0 ] for i in range (max_steps )]
576+
570577 screenshots_and_label = [(s , f"Step { i } " ) for i , s in enumerate (screenshots )]
578+
571579 gallery = gr .Gallery (
572580 value = screenshots_and_label ,
573581 columns = 2 ,
@@ -595,7 +603,8 @@ def update_pruned_html():
595603
596604
def update_axtree():
    """Return the AXTree text of the current observation as a fenced markdown code block.

    Returns:
        A markdown string wrapping the ``axtree_txt`` observation (or
        ``"No AXTree"`` when ``get_obs`` falls back to its default) in a code
        fence.
    """
    text = str(get_obs(key="axtree_txt", default="No AXTree"))

    # A fence is only closed by a backtick run at least as long as the opening
    # one, so grow the fence until it cannot appear inside the text itself.
    fence = "```"
    while fence in text:
        fence += "`"

    return f"{fence}\n{text}\n{fence}"
599608
600609
601610def dict_to_markdown (d : dict ):
@@ -645,7 +654,7 @@ def dict_msg_to_markdown(d: dict):
645654 case "text" :
646655 parts .append (f"\n ```\n { item ['text' ]} \n ```\n " )
647656 case "tool_use" :
648- tool_use = f"Tool Use: { item [' name' ] } { item [' input' ] } (id = { item ['id' ] } )"
657+ tool_use = _format_tool_call ( item [" name" ], item [" input" ], item ["call_id" ])
649658 parts .append (f"\n ```\n { tool_use } \n ```\n " )
650659 case _:
651660 parts .append (f"\n ```\n { str (item )} \n ```\n " )
@@ -655,27 +664,40 @@ def dict_msg_to_markdown(d: dict):
655664 return markdown
656665
657666
667+ def _format_tool_call (name : str , input : str , call_id : str ):
668+ """
669+ Format a tool call to markdown.
670+ """
671+ return f"Tool Call: { name } `{ input } ` (call_id: { call_id } )"
672+
673+
def format_chat_message(
    message: BaseMessage | MessageBuilder | dict | ResponseFunctionToolCall,
) -> str:
    """Format a single chat message as markdown.

    Args:
        message: The message to render. Langchain messages, ``MessageBuilder``
            instances, raw dict messages and OpenAI function tool calls each
            get a dedicated rendering; anything else is passed through
            ``str``.

    Returns:
        The markdown text for the message.
    """
    if isinstance(message, BaseMessage):
        content = message.content
        # Langchain content can be a list of parts rather than a string; the
        # caller joins the results with str.join, so always hand back text.
        return content if isinstance(content, str) else str(content)
    if isinstance(message, MessageBuilder):
        return message.to_markdown()
    if isinstance(message, dict):
        return dict_msg_to_markdown(message)
    if isinstance(message, ResponseFunctionToolCall):
        tool_use_str = _format_tool_call(message.name, message.arguments, message.call_id)
        return f"### Tool Use\n```\n{tool_use_str}\n```\n"
    return str(message)
689+
690+
def update_chat_messages():
    """Return the chat messages of the current step rendered as markdown.

    Reads ``chat_messages`` from the current step's ``agent_info``; when the
    key is absent a ``"No Chat Messages"`` placeholder is rendered instead.

    Returns:
        A markdown string: the ``Discussion``'s own markdown, the formatted
        messages of a list joined by blank lines, or a plain-text fallback
        for any other stored value.
    """
    global info
    agent_info = info.exp_result.steps_info[info.step].agent_info
    chat_messages = agent_info.get("chat_messages", ["No Chat Messages"])

    if isinstance(chat_messages, Discussion):
        return chat_messages.to_markdown()

    if isinstance(chat_messages, list):
        return "\n\n".join(format_chat_message(m) for m in chat_messages)

    # Any other stored value used to fall through and return None to the
    # Markdown component; render it as text instead.
    if chat_messages is None:
        return "No Chat Messages"
    return str(chat_messages)
679701
680702
681703def update_task_error ():
@@ -722,8 +744,8 @@ def update_agent_info_html():
722744 global info
723745 # screenshots from current and next step
724746 try :
725- s1 = get_screenshot (info , info .step , False )
726- s2 = get_screenshot (info , info .step + 1 , False )
747+ s1 , action_str = get_screenshot (info , info .step , False )
748+ s2 , action_str = get_screenshot (info , info .step + 1 , False )
727749 agent_info = info .exp_result .steps_info [info .step ].agent_info
728750 page = agent_info .get ("html_page" , ["No Agent Info" ])
729751 if page is None :
@@ -854,6 +876,8 @@ def get_episode_info(info: Info):
854876
855877def get_action_info (info : Info ):
856878 steps_info = info .exp_result .steps_info
879+ img , action_str = get_screenshot (info , step = info .step , annotate = True ) # to update click_mapper
880+
857881 if len (steps_info ) == 0 :
858882 return "No steps were taken"
859883 if len (steps_info ) <= info .step :
@@ -863,7 +887,7 @@ def get_action_info(info: Info):
863887 action_info = f"""\
864888 **Action:**
865889
866- { code ( step_info . action ) }
890+ { action_str }
867891"""
868892 think = step_info .agent_info .get ("think" , None )
869893 if think is not None :
@@ -1084,16 +1108,19 @@ def get_directory_contents(results_dir: Path):
10841108 most_recent_summary = max (summary_files , key = os .path .getctime )
10851109 summary_df = pd .read_csv (most_recent_summary )
10861110
1111+ if len (summary_df ) == 0 or summary_df ["avg_reward" ].isna ().all ():
1112+ continue # skip if all avg_reward are NaN
1113+
10871114 # get row with max avg_reward
1088- max_reward_row = summary_df .loc [summary_df ["avg_reward" ].idxmax ()]
1115+ max_reward_row = summary_df .loc [summary_df ["avg_reward" ].idxmax (skipna = True )]
10891116 reward = max_reward_row ["avg_reward" ] * 100
10901117 completed = max_reward_row ["n_completed" ]
10911118 n_err = max_reward_row ["n_err" ]
10921119 exp_description += (
10931120 f" - avg-reward: { reward :.1f} % - completed: { completed } - errors: { n_err } "
10941121 )
10951122 except Exception as e :
1096- print (f"Error while reading summary file: { e } " )
1123+ print (f"Error while reading summary file { most_recent_summary } : { e } " )
10971124
10981125 exp_descriptions .append (exp_description )
10991126
@@ -1219,7 +1246,6 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
12191246 horizontalalignment = "left" ,
12201247 rotation = 0 ,
12211248 clip_on = True ,
1222- antialiased = True ,
12231249 fontweight = 1000 ,
12241250 backgroundcolor = colors [12 ],
12251251 )
0 commit comments